Compare commits

...

83 Commits

Author SHA1 Message Date
Extraltodeus
ccfd4088d0
Merge branch 'Comfy-Org:master' into master 2026-01-14 01:19:29 +01:00
Alexander Piskun
1419047fdb
[Api Nodes]: Improve Price Badge Declarations (#11582)
* api nodes: price badges moved to nodes code

* added price badges for 4 more node-packs

* added price badges for 10 more node-packs

* added new price badges for Omni STD mode

* add support for autogrow groups

* use full names for "widgets", "inputs" and "groups"

* add strict typing for JSONata rules

* add price badge for WanReferenceVideoApi node

* add support for DynamicCombo

* sync price badges changes (https://github.com/Comfy-Org/ComfyUI_frontend/pull/7900)

* sync badges for Vidu2 nodes

* fixed incorrect price for RecraftCrispUpscaleNode

* fixed incorrect price badges for LTXV nodes

* fixed price badge for MinimaxHailuoVideoNode

* fixed price badges for PixVerse nodes
2026-01-13 16:18:28 -08:00
ric-yu
79f6bb5e4f
add blueprints dir for built-in blueprints (#11853) 2026-01-13 16:14:40 -08:00
Jukka Seppänen
e4b4fb3479
Load metadata on VAELoader (#11846)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Needed to load the proper LTX2 VAE if separated from checkpoint
2026-01-13 17:37:21 -05:00
Acly
d9dc02a7d6
Support "lite" version of alibaba-pai Z-Image Controlnet (#11849)
* reduced number of control layers (3) compared to full model
2026-01-13 15:03:53 -05:00
Alexander Piskun
c543ad81c3
fix(api-nodes-gemini): raise exception when no candidates due to safety block (#11848) 2026-01-13 08:30:13 -08:00
comfyanonymous
5ac1372533 ComfyUI v0.9.1
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-01-13 01:44:06 -05:00
comfyanonymous
1dcbd9efaf
Bump ltxav mem estimation a bit. (#11842) 2026-01-13 01:42:07 -05:00
comfyanonymous
db9e6edfa1 ComfyUI v0.9.0 2026-01-13 01:23:31 -05:00
Christian Byrne
8af13b439b
Update requirements.txt (#11841) 2026-01-13 01:22:25 -05:00
Jedrzej Kosinski
acd0e53653
Make bulk_ops not use .returning to be compatible with python 3.10 and 3.11 sqlalchemy (#11839) 2026-01-13 00:15:24 -05:00
comfyanonymous
117e7a5853
Refactor to try to lower mem usage. (#11840) 2026-01-12 21:01:52 -08:00
comfyanonymous
b3c0e4de57
Make loras work on nvfp4 models. (#11837)
The initial applying is a bit slow but will probably be sped up in the
future.
2026-01-12 22:33:54 -05:00
ComfyUI Wiki
ecaeeb990d
chore: update workflow templates to v0.8.4 (#11835) 2026-01-12 19:18:01 -08:00
ComfyUI Wiki
c2b65e2fce
Update workflow templates to v0.8.0 (#11828)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-12 17:29:25 -05:00
Jukka Seppänen
fd5c0755af
Reduce LTX2 VRAM use by more efficient timestep embed handling (#11829) 2026-01-12 17:28:59 -05:00
comfyanonymous
c881a1d689
Support the siglip 2 naflex model as a clip vision model. (#11831)
Not useful yet.
2026-01-12 17:05:54 -05:00
kelseyee
a3b5d4996a
Support ModelScope-Trainer DiffSynth lora for Z Image. (#11805) 2026-01-12 15:38:46 -05:00
comfyanonymous
c6238047ee
Put more details about portable in readme. (#11816)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-11 21:11:53 -05:00
Alexander Piskun
5cd1113236
fix(api-nodes): use a unique name for uploading audio files (#11778)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
2026-01-11 03:07:11 -08:00
comfyanonymous
2f642d5d9b
Fix chroma fp8 te being treated as fp16. (#11795)
Some checks are pending
Python Linting / Run Pylint (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
2026-01-10 14:40:42 -08:00
comfyanonymous
cd912963f1
Fix issue with t5 text encoder in fp4. (#11794) 2026-01-10 17:31:31 -05:00
DELUXA
6e4b1f9d00
pythorch_attn_by_def_on_gfx1200 (#11793) 2026-01-10 16:51:05 -05:00
comfyanonymous
dc202a2e51
Properly save mixed ops. (#11772)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-01-10 02:03:57 -05:00
ComfyUI Wiki
153bc524bf
chore: update embedded docs to v0.4.0 (#11776) 2026-01-10 01:29:30 -05:00
Alexander Piskun
393d2880dd
feat(api-nodes): added nodes for Vidu2 (#11760)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
2026-01-09 12:59:38 -08:00
Alexander Piskun
4484b93d61
fix(api-nodes): do not downscale the input image for Topaz Enhance (#11768) 2026-01-09 12:25:56 -08:00
comfyanonymous
bd0e6825e8
Be less strict when loading mixed ops weights. (#11769) 2026-01-09 14:21:06 -05:00
Jedrzej Kosinski
ec0a832acb
Add workaround for hacky nodepack(s) that edit folder_names_and_paths to have values with tuples of more than 2. Other things could potentially break with those nodepack(s), so I will hunt for the guilty nodepack(s) now. (#11755)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-08 22:49:12 -08:00
ric-yu
04c49a29b4
feat: add cancelled filter to /jobs (#11680) 2026-01-08 21:57:36 -08:00
Terry Jia
4609fcd260
add node - image compare (#11343)
Some checks failed
Python Linting / Run Pylint (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-01-08 21:31:19 -08:00
rattus
6207f86c18
Fix VAEEncodeForInpaint to support WAN VAE tuple downscale_ratio (#11572)
Use vae.spacial_compression_encode() instead of directly accessing
downscale_ratio to handle both standard VAEs (int) and WAN VAEs (tuple).

Addresses reviewer feedback on PR #11259.

Co-authored-by: ChrisFab16 <christopher@fabritius.dk>
2026-01-08 23:34:48 -05:00
Jedrzej Kosinski
1dc3da6314
Add most basic Asset support for models (#11315)
* Brought over minimal elements from PR 10045 to reproduce seed_assets and register_assets_system without adding anything to the DB or server routes yet, for now making everything sync (can introduce async once everything is cleaned up and brought over)

* Added db script to insert assets stuff, cleaned up some code; assets (models) now get added/rescanned

* Added support for 5 http endpoints for assets

* Replaced Optional with | None in schemas_in.py and schemas_out.py

* Remove two routes that will not be relevant yet in this PR: HEAD /api/assets/hash/<hash> and PUT /api/assets/<id>/preview

* Remove some functions the two deleted endpoints were using

* Don't show assets scan message upon calling /object_info endpoint

* removed unsued import to satisfy ruff

* Simplified hashing function tpye hint and _hash_file_obj

* Satisfied ruff
2026-01-08 22:21:51 -05:00
Comfy Org PR Bot
114fc73685
Bump comfyui-frontend-package to 1.36.13 (#11645) 2026-01-08 22:16:15 -05:00
comfyanonymous
b48d6a83d4
Fix csp error in frontend when forcing offline. (#11749) 2026-01-08 22:15:50 -05:00
Jukka Seppänen
027042db68
Add node: JoinAudioChannels (#11728) 2026-01-08 22:14:06 -05:00
comfyanonymous
1a20656448
Fix import issue. (#11746) 2026-01-08 17:23:59 -05:00
comfyanonymous
0f11869d55
Better detection if AMD torch compiled with efficient attention. (#11745) 2026-01-08 17:16:58 -05:00
Dr.Lt.Data
5943fbf457
bump comfyui_manager version to the 4.0.5 (#11732)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-08 08:15:42 -08:00
Yoland Yan
a60b7b86c5
Revert "Force sequential execution in CI test jobs (#11687)" (#11725)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This reverts commit ce0000c4f2.
2026-01-07 21:41:57 -08:00
comfyanonymous
2e9d51680a ComfyUI version v0.8.2
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-07 23:50:02 -05:00
comfyanonymous
50d6e1caf4
Tweak ltxv vae mem estimation. (#11722) 2026-01-07 23:07:05 -05:00
comfyanonymous
ac12f77bed ComfyUI version v0.8.1 2026-01-07 22:10:08 -05:00
ComfyUI Wiki
fcd9a236b0
Update template to 0.7.69 (#11719) 2026-01-07 18:22:23 -08:00
comfyanonymous
21e8425087
Add warning for old pytorch. (#11718) 2026-01-07 21:07:26 -05:00
rattus
b6c79a648a
ops: Fix offloading with FP8MM performance (#11697)
This logic was checking comfy_cast_weights, and going straight to
to the forward_comfy_cast_weights implementation without
attempting to downscale input to fp8 in the event comfy_cast_weights
is set.

The main reason comfy_cast_weights would be set would be for async
offload, which is not a good reason to nix FP8MM.

So instead, and together the underlying exclusions for FP8MM which
are:

* having a weight_function (usually LowVramPatch)
* force_cast_weights (compute dtype override)
* the weight is not Quantized
* the input is already quantized
* the model or layer has MM explictily disabled.

If you get past all of those exclusions, quantize the input tensor.
Then hand the new input, quantized or not off to
forward_comfy_cast_weights to handle it. If the weight is offloaded
but input is quantized you will get an offloaded MM8.
2026-01-07 21:01:16 -05:00
comfyanonymous
25bc1b5b57
Add memory estimation function to ltxav text encoder. (#11716) 2026-01-07 20:11:22 -05:00
comfyanonymous
3cd19e99c1
Increase ltxav mem estimation by a bit. (#11715) 2026-01-07 20:04:56 -05:00
comfyanonymous
007b87e7ac
Bump required comfy-kitchen version. (#11714) 2026-01-07 19:48:47 -05:00
comfyanonymous
34751fe9f9
Lower ltxv text encoder vram use. (#11713) 2026-01-07 19:12:15 -05:00
Jukka Seppänen
1c705f7bfb
Add device selection for LTXAVTextEncoderLoader (#11700) 2026-01-07 18:39:59 -05:00
rattus
48e5ea1dfd
model_patcher: Remove confusing load stat (#11710)
If the loader passes 1e32 as the usable memory size, it means force
the full load. This happens with CPU loads and a few other misc cases.
Removing the confusing number and just leave the other details.
2026-01-07 18:39:20 -05:00
comfyanonymous
3cd7b32f1b
Support gemma 12B with quant weights. (#11696)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-07 05:15:14 -05:00
comfyanonymous
c0c9720d77
Fix stable release workflow not pulling latest comfy kitchen. (#11695) 2026-01-07 04:48:28 -05:00
comfyanonymous
fc0cb10bcb ComfyUI v0.8.0 2026-01-07 04:07:31 -05:00
comfyanonymous
b7d7cc1d49
Fix fp8 fast issue. (#11688) 2026-01-07 01:39:06 -05:00
Alexander Piskun
79e94544bd
feat(api-nodes): add WAN2.6 ReferenceToVideo (#11644) 2026-01-06 22:04:50 -08:00
Yoland Yan
ce0000c4f2
Force sequential execution in CI test jobs (#11687)
Added max-parallel setting to enforce sequential execution in test jobs.
2026-01-07 00:57:31 -05:00
comfyanonymous
c5cfb34c07
Update comfy-kitchen version to 0.2.3 (#11685) 2026-01-06 23:51:45 -05:00
comfyanonymous
edee33f55e
Disable comfy kitchen cuda if pytorch cuda less than 13 (#11681)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-06 22:13:43 -05:00
comfyanonymous
2c03884f5f
Skip fp4 matrix mult on devices that don't support it. (#11677) 2026-01-06 18:07:26 -05:00
comfyanonymous
6e9ee55cdd
Disable ltxav previews. (#11676) 2026-01-06 17:41:27 -05:00
comfyanonymous
023cf13721
Fix lowvram issue with ltxv2 text encoder. (#11675) 2026-01-06 17:33:03 -05:00
ComfyUI Wiki
c3566c0d76
chore: update workflow templates to v0.7.67 (#11667) 2026-01-06 14:28:29 -08:00
comfyanonymous
c3c3e93c5b
Use rope functions from comfy kitchen. (#11674) 2026-01-06 16:57:50 -05:00
comfyanonymous
6ffc159bdd
Update comfy-kitchen version to 0.2.1 (#11672) 2026-01-06 15:53:43 -05:00
comfyanonymous
96e0d0924e
Add helpful message to portable. (#11671)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-06 14:43:24 -05:00
ComfyUI Wiki
e14f3b6610
chore: update workflow templates to v0.7.66 (#11652)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-05 22:37:11 -08:00
comfyanonymous
1618002411
Revert "Use rope functions from comfy kitchen. (#11647)" (#11648)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This reverts commit 6ef85c4915.
2026-01-05 23:07:39 -05:00
comfyanonymous
6ef85c4915
Use rope functions from comfy kitchen. (#11647) 2026-01-05 22:50:35 -05:00
comfyanonymous
6da00dd899
Initial ops changes to use comfy_kitchen: Initial nvfp4 checkpoint support. (#11635)
---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-01-05 21:48:58 -05:00
comfyanonymous
4f3f9e72a9
Fix name. (#11638)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-05 02:41:23 -08:00
comfyanonymous
d157c3299d
Refactor module_size function. (#11637) 2026-01-05 03:48:31 -05:00
comfyanonymous
d1b9822f74
Add LTXAVTextEncoderLoader node. (#11634) 2026-01-05 02:27:31 -05:00
comfyanonymous
f2b002372b
Support the LTXV 2 model. (#11632) 2026-01-05 01:58:59 -05:00
comfyanonymous
38d0493825
Fix case where upscale model wouldn't be moved to cpu. (#11633)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-04 19:13:50 -05:00
Alexander Piskun
acbf08cd60
feat(api-nodes): add support for 720p resolution for Kling Omni nodes (#11604)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
2026-01-03 23:05:02 -08:00
comfyanonymous
53e762a3af
Print memory summary on OOM to help with debugging. (#11613) 2026-01-03 22:28:38 -05:00
comfyanonymous
9a552df898
Remove leftover scaled_fp8 key. (#11603)
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
2026-01-02 17:28:10 -08:00
Alexander Piskun
f2fda021ab
Tripo3D: pass face_limit parameter only when it differs from default (#11601)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-02 03:18:43 -08:00
throttlekitty
303b1735f8
Give Mahiro CFG a more appropriate display name (#11580)
Some checks are pending
Python Linting / Run Pylint (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-02 00:37:37 -08:00
Alexander Piskun
9e5f677746
Ignore all frames except the first one for MPO format. (#11569) 2026-01-02 00:35:34 -08:00
comfyanonymous
65cfcf5b1b
New Year ruff cleanup. (#11595) 2026-01-01 22:06:14 -05:00
109 changed files with 8999 additions and 1311 deletions

View File

@ -1,3 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
pause

View File

@ -1,3 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
pause

View File

@ -1,3 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe
pause

View File

@ -117,7 +117,7 @@ jobs:
./python.exe get-pip.py
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt

View File

@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}

View File

@ -32,7 +32,9 @@ jobs:
working-directory: ComfyUI
- name: Check for unhandled exceptions in server log
run: |
if grep -qE "Exception|Error" console_output.log; then
grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log
cat console_output_filtered.log
if grep -qE "Exception|Error" console_output_filtered.log; then
echo "Unhandled exception/error found in server log."
exit 1
fi

View File

@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start.
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads:
@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old.
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
### Instructions:

View File

@ -0,0 +1,174 @@
"""
Initial assets schema
Revision ID: 0001_assets
Revises: None
Create Date: 2025-12-10 00:00:00
"""
from alembic import op
import sqlalchemy as sa
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text_encoders", "tag_type": "system"},
{"name": "diffusion_models", "tag_type": "system"},
{"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

102
app/assets/api/routes.py Normal file
View File

@ -0,0 +1,102 @@
import logging
import uuid
from aiohttp import web
from pydantic import ValidationError
import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
"""
GET request to list assets.
"""
query_dict = get_query_dict(request)
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
"""
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as e:
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
except Exception:
logging.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
"""
GET request to list all tags based on query parameters.
"""
query_map = dict(request.rel_url.query)
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as e:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
status=400,
)
result = manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))

View File

@ -0,0 +1,94 @@
import json
import uuid
from typing import Any, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
# Accept either a JSON string (query param) or a dict
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
prefix: str | None = Field(None, min_length=1, max_length=256)
limit: int = Field(100, ge=1, le=1000)
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
@field_validator("prefix")
@classmethod
def normalize_prefix(cls, v: str | None) -> str | None:
if v is None:
return v
v = v.strip()
return v.lower() or None
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: str | None = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s

View File

@ -0,0 +1,60 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetsList(BaseModel):
assets: list[AssetSummary]
total: int
has_more: bool
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool

View File

@ -0,0 +1,204 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@ -0,0 +1,233 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Asset | None] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list[Tag]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list[AssetInfo]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@ -0,0 +1,267 @@
import sqlalchemy as sa
from collections import defaultdict
from sqlalchemy import select, exists, func
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)

View File

@ -0,0 +1,62 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

75
app/assets/hashing.py Normal file
View File

@ -0,0 +1,75 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

217
app/assets/helpers.py Normal file
View File

@ -0,0 +1,217 @@
import contextlib
import os
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Any
import folder_paths
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
"""
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
Normalize a list of tags by:
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out

123
app/assets/manager.py Normal file
View File

@ -0,0 +1,123 @@
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out
from app.assets.database.queries import (
asset_exists_by_hash,
fetch_asset_info_asset_and_tags,
list_asset_infos_page,
list_tags_with_usage,
)
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

229
app/assets/scanner.py Normal file
View File

@ -0,0 +1,229 @@
import contextlib
import time
import logging
import os
import sqlalchemy
import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
RootType
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""
Scan the given roots and seed the assets into the database.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
len(paths),
)
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None

View File

@ -1,14 +1,21 @@
from sqlalchemy.orm import declarative_base
from typing import Any
from datetime import datetime
from sqlalchemy.orm import DeclarativeBase
Base = declarative_base()
class Base(DeclarativeBase):
pass
def to_dict(obj):
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys()
return {
field: (val.to_dict() if hasattr(val, "to_dict") else val)
for field in fields
if (val := getattr(obj, field))
}
out: dict[str, Any] = {}
for field in fields:
val = getattr(obj, field)
if val is None and not include_none:
continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
# TODO: Define models here

View File

@ -44,7 +44,7 @@ class ModelFileManager:
@routes.get("/experiment/models/{folder}")
async def get_all_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
if folder not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = self.get_model_file_list(folder)
return web.json_response(files)
@ -55,7 +55,7 @@ class ModelFileManager:
path_index = int(request.match_info.get("path_index", None))
filename = request.match_info.get("filename", None)
if not folder_name in folder_paths.folder_names_and_paths:
if folder_name not in folder_paths.folder_names_and_paths:
return web.Response(status=404)
folders = folder_paths.folder_names_and_paths[folder_name]

View File

@ -10,6 +10,7 @@ import hashlib
class Source:
custom_node = "custom_node"
templates = "templates"
class SubgraphEntry(TypedDict):
source: str
@ -38,6 +39,18 @@ class CustomNodeSubgraphEntryInfo(TypedDict):
class SubgraphManager:
def __init__(self):
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None
def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]:
"""Create a subgraph entry from a file path. Expects normalized path (forward slashes)."""
entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": source,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": {"node_pack": node_pack},
}
return entry_id, entry
async def load_entry_data(self, entry: SubgraphEntry):
with open(entry['path'], 'r') as f:
@ -60,53 +73,60 @@ class SubgraphManager:
return entries
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
# if not forced to reload and cached, return cache
"""Load subgraphs from custom nodes."""
if not force_reload and self.cached_custom_node_subgraphs is not None:
return self.cached_custom_node_subgraphs
# Load subgraphs from custom nodes
subfolder = "subgraphs"
subgraphs_dict: dict[SubgraphEntry] = {}
subgraphs_dict: dict[SubgraphEntry] = {}
for folder in folder_paths.get_folder_paths("custom_nodes"):
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
matched_files = glob.glob(pattern)
for file in matched_files:
# replace backslashes with forward slashes
pattern = os.path.join(folder, "*/subgraphs/*.json")
for file in glob.glob(pattern):
file = file.replace('\\', '/')
info: CustomNodeSubgraphEntryInfo = {
"node_pack": "custom_nodes." + file.split('/')[-3]
}
source = Source.custom_node
# hash source + path to make sure id will be as unique as possible, but
# reproducible across backend reloads
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
entry: SubgraphEntry = {
"source": Source.custom_node,
"name": os.path.splitext(os.path.basename(file))[0],
"path": file,
"info": info,
}
subgraphs_dict[id] = entry
node_pack = "custom_nodes." + file.split('/')[-3]
entry_id, entry = self._create_entry(file, Source.custom_node, node_pack)
subgraphs_dict[entry_id] = entry
self.cached_custom_node_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_custom_node_subgraph(self, id: str, loadedModules):
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
entry: SubgraphEntry = subgraphs.get(id, None)
if entry is not None and entry.get('data', None) is None:
async def get_blueprint_subgraphs(self, force_reload=False):
"""Load subgraphs from the blueprints directory."""
if not force_reload and self.cached_blueprint_subgraphs is not None:
return self.cached_blueprint_subgraphs
subgraphs_dict: dict[SubgraphEntry] = {}
blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints')
if os.path.exists(blueprints_dir):
for file in glob.glob(os.path.join(blueprints_dir, "*.json")):
file = file.replace('\\', '/')
entry_id, entry = self._create_entry(file, Source.templates, "comfyui")
subgraphs_dict[entry_id] = entry
self.cached_blueprint_subgraphs = subgraphs_dict
return subgraphs_dict
async def get_all_subgraphs(self, loadedModules, force_reload=False):
"""Get all subgraphs from all sources (custom nodes and blueprints)."""
custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload)
blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload)
return {**custom_node_subgraphs, **blueprint_subgraphs}
async def get_subgraph(self, id: str, loadedModules):
"""Get a specific subgraph by ID from any source."""
entry = (await self.get_all_subgraphs(loadedModules)).get(id)
if entry is not None and entry.get('data') is None:
await self.load_entry_data(entry)
return entry
def add_routes(self, routes, loadedModules):
@routes.get("/global_subgraphs")
async def get_global_subgraphs(request):
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
# that's the reasoning for the current implementation
subgraphs_dict = await self.get_all_subgraphs(loadedModules)
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
@routes.get("/global_subgraphs/{id}")
async def get_global_subgraph(request):
id = request.match_info.get("id", None)
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
subgraph = await self.get_subgraph(id, loadedModules)
return web.json_response(await self.sanitize_entry(subgraph))

View File

View File

@ -237,6 +237,7 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:
args = parser.parse_args()

View File

@ -1,6 +1,7 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
import math
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image
@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
def scale_dim(size, scale):
scaled = math.ceil(size * scale / patch_size) * patch_size
return max(patch_size, int(scaled))
# Binary search for optimal scale
lo, hi = eps / 10, 100.0
while hi - lo >= eps:
mid = (lo + hi) / 2
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
if (h // patch_size) * (w // patch_size) <= max_num_patches:
lo = mid
else:
hi = mid
return scale_dim(oh, lo), scale_dim(ow, lo)
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
if size > 0:
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
b, c, h, w = image.shape
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
out = self.text_projection(x[2])
return (x[0], x[1], out, x[2])
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
return embeds + embed_weight
class Siglip2Embeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
super().__init__()
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
self.patch_size = patch_size
def forward(self, pixel_values):
b, c, h, w = pixel_values.shape
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
img = img.permute(0, 1, 3, 2, 4, 5)
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
img = self.patch_embedding(img)
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module):
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
if model_type in ["siglip2_vision_model"]:
self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
else:
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:

View File

@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
@ -32,9 +33,10 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type)
if model_type == "siglip_vision_model":
self.model_type = config.get("model_type", "clip_vision_model")
self.config = config.copy()
model_class = IMAGE_ENCODERS.get(self.model_type)
if self.model_type == "siglip_vision_model":
self.return_all_hidden_states = True
else:
self.return_all_hidden_states = False
@ -55,7 +57,10 @@ class ClipVisionModel():
def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
if self.model_type == "siglip2_vision_model":
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
else:
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output()
@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
if len(patch_embedding_shape) == 2:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
else:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")

View File

@ -0,0 +1,14 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": -1,
"intermediate_size": 4304,
"model_type": "siglip2_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"num_patches": 256,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
return output
return value.to(dtype=dtype)
# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
orig_shape = x.shape
sign = torch.signbit(x).to(torch.uint8)
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
x = x.abs()
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
mantissa = torch.where(
exp > 0,
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
(x * 2.0),
out=x
).round().to(torch.uint8)
del x
exp = exp.to(torch.uint8)
fp4 = (sign << 3) | (exp << 1) | mantissa
del sign, exp, mantissa
fp4_flat = fp4.view(-1)
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
return packed.reshape(list(orig_shape)[:-1] + [-1])
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
def ceil_div(a, b):
return (a + b - 1) // b
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
padded = input_matrix
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros(
(padded_rows, padded_cols),
device=input_matrix.device,
dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
if flatten:
return rearranged.flatten()
return rearranged.reshape(padded_rows, padded_cols)
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
orig_shape = x.shape
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape
block_size = 16
x = x.reshape(orig_shape[0], -1, block_size)
max_abs = torch.amax(torch.abs(x), dim=-1)
block_scale = max_abs / F4_E2M1_MAX
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
# Handle zero blocks (from padding): avoid 0/0 NaN
zero_scale_mask = (total_scale == 0)
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
x = x / total_scale_safe.unsqueeze(-1)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
x = x.view(orig_shape)
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales

View File

@ -527,7 +527,8 @@ class HookKeyframeGroup:
if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
else:
break
# update steps current context is used
self._current_used_steps += 1
# update current timestep this was performed on

View File

@ -407,6 +407,11 @@ class LTXV(LatentFormat):
self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]
class LTXAV(LTXV):
def __init__(self):
self.latent_rgb_factors = None
self.latent_rgb_factors_bias = None
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3

View File

@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
bad_keys = tuple(
k
for k, v in overrides.items()
if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
)
if bad_keys:
e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"

View File

@ -4,6 +4,7 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import logging
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
try:
import comfy.quant_ops
apply_rope = comfy.quant_ops.ck.apply_rope
apply_rope1 = comfy.quant_ops.ck.apply_rope1
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
return x_out.reshape(*x.shape).type_as(x)
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)

View File

@ -3,7 +3,8 @@ import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management, model_patcher
import comfy.model_management
import comfy.model_patcher
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
@ -102,13 +103,13 @@ UPSAMPLERS = {
class HunyuanVideo15SRModel():
def __init__(self, model_type, config):
self.load_device = model_management.vae_device()
offload_device = model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device)
self.load_device = comfy.model_management.vae_device()
offload_device = comfy.model_management.vae_offload_device()
self.dtype = comfy.model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True)
@ -117,5 +118,5 @@ class HunyuanVideo15SRModel():
return self.model.state_dict()
def resample_latent(self, latent):
model_management.load_model_gpu(self.patcher)
comfy.model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device))

View File

@ -0,0 +1,913 @@
from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
LTXVModel,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
# All patches in a frame are identical, so we only keep the first one
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1
self.num_frames = num_tokens
self.data = tensor
def expand(self):
"""Expand back to original tensor."""
if self.patches_per_frame == 1:
return self.data
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
return expanded.reshape(self.batch_size, -1, self.feature_dim)
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
"""Compute ada values on compressed per-frame data, then expand spatially."""
num_ada_params = scale_shift_table.shape[0]
# No compression - compute directly
if self.patches_per_frame == 1:
num_tokens = self.data.shape[1]
dim_per_param = self.feature_dim // num_ada_params
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
ada_values = (table_values + reshaped).unbind(dim=2)
return ada_values
# Compressed: compute on per-frame data then expand spatially
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
device=self.data.device, dtype=self.data.dtype
)
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
return tuple(
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
.reshape(batch_size, -1, frame_val.shape[-1])
for frame_val in frame_ada
)
class BasicAVTransformerBlock(nn.Module):
def __init__(
self,
v_dim,
a_dim,
v_heads,
a_heads,
vd_head,
ad_head,
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(
query_dim=v_dim,
heads=v_heads,
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn1 = CrossAttention(
query_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.attn2 = CrossAttention(
query_dim=v_dim,
context_dim=v_context_dim,
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.audio_attn2 = CrossAttention(
query_dim=a_dim,
context_dim=a_context_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Video, K,V: Audio
self.audio_to_video_attn = CrossAttention(
query_dim=v_dim,
context_dim=a_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
# Q: Audio, K,V: Video
self.video_to_audio_attn = CrossAttention(
query_dim=a_dim,
context_dim=v_dim,
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.ff = FeedForward(
v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.audio_ff = FeedForward(
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
self.scale_shift_table_a2v_ca_video = nn.Parameter(
torch.empty(5, v_dim, device=device, dtype=dtype)
)
def get_ada_values(
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
):
if isinstance(timestep, CompressedTimestep):
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
num_ada_params = scale_shift_table.shape[0]
ada_values = (
scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype)
+ timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :]
).unbind(dim=2)
return ada_values
def get_av_ca_ada_values(
self,
scale_shift_table: torch.Tensor,
batch_size: int,
scale_shift_timestep: torch.Tensor,
gate_timestep: torch.Tensor,
num_scale_shift_values: int = 4,
):
scale_shift_ada_values = self.get_ada_values(
scale_shift_table[:num_scale_shift_values, :],
batch_size,
scale_shift_timestep,
)
gate_ada_values = self.get_ada_values(
scale_shift_table[num_scale_shift_values:, :],
batch_size,
gate_timestep,
)
return (*scale_shift_ada_values, *gate_ada_values)
def forward(
self,
x: Tuple[torch.Tensor, torch.Tensor],
v_context=None,
a_context=None,
attention_mask=None,
v_timestep=None,
a_timestep=None,
v_pe=None,
a_pe=None,
v_cross_pe=None,
a_cross_pe=None,
v_cross_scale_shift_timestep=None,
a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None,
a_cross_gate_timestep=None,
transformer_options=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
vx, ax = x
run_ax = run_ax and ax.numel() > 0
run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0
run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True)
if run_vx:
vshift_msa, vscale_msa, vgate_msa = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3))
)
norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa
vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa
vx += self.attn2(
comfy.ldm.common_dit.rms_norm(vx),
context=v_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del vshift_msa, vscale_msa, vgate_msa
if run_ax:
ashift_msa, ascale_msa, agate_msa = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3))
)
norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa
ax += (
self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options)
* agate_msa
)
ax += self.audio_attn2(
comfy.ldm.common_dit.rms_norm(ax),
context=a_context,
mask=attention_mask,
transformer_options=transformer_options,
)
del ashift_msa, ascale_msa, agate_msa
# Audio - Video cross attention.
if run_a2v or run_v2a:
# norm3
vx_norm3 = comfy.ldm.common_dit.rms_norm(vx)
ax_norm3 = comfy.ldm.common_dit.rms_norm(ax)
(
scale_ca_audio_hidden_states_a2v,
shift_ca_audio_hidden_states_a2v,
scale_ca_audio_hidden_states_v2a,
shift_ca_audio_hidden_states_v2a,
gate_out_v2a,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_audio,
ax.shape[0],
a_cross_scale_shift_timestep,
a_cross_gate_timestep,
)
(
scale_ca_video_hidden_states_a2v,
shift_ca_video_hidden_states_a2v,
scale_ca_video_hidden_states_v2a,
shift_ca_video_hidden_states_v2a,
gate_out_a2v,
) = self.get_av_ca_ada_values(
self.scale_shift_table_a2v_ca_video,
vx.shape[0],
v_cross_scale_shift_timestep,
v_cross_gate_timestep,
)
if run_a2v:
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_a2v)
+ shift_ca_video_hidden_states_a2v
)
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v)
+ shift_ca_audio_hidden_states_a2v
)
vx += (
self.audio_to_video_attn(
vx_scaled,
context=ax_scaled,
pe=v_cross_pe,
k_pe=a_cross_pe,
transformer_options=transformer_options,
)
* gate_out_a2v
)
del gate_out_a2v
del scale_ca_video_hidden_states_a2v,\
shift_ca_video_hidden_states_a2v,\
scale_ca_audio_hidden_states_a2v,\
shift_ca_audio_hidden_states_a2v,\
if run_v2a:
ax_scaled = (
ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a)
+ shift_ca_audio_hidden_states_v2a
)
vx_scaled = (
vx_norm3 * (1 + scale_ca_video_hidden_states_v2a)
+ shift_ca_video_hidden_states_v2a
)
ax += (
self.video_to_audio_attn(
ax_scaled,
context=vx_scaled,
pe=a_cross_pe,
k_pe=v_cross_pe,
transformer_options=transformer_options,
)
* gate_out_v2a
)
del gate_out_v2a
del scale_ca_video_hidden_states_v2a,\
shift_ca_video_hidden_states_v2a,\
scale_ca_audio_hidden_states_v2a,\
shift_ca_audio_hidden_states_v2a
if run_vx:
vshift_mlp, vscale_mlp, vgate_mlp = (
self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None))
)
vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp
vx += self.ff(vx_scaled) * vgate_mlp
del vshift_mlp, vscale_mlp, vgate_mlp
if run_ax:
ashift_mlp, ascale_mlp, agate_mlp = (
self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None))
)
ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp
ax += self.audio_ff(ax_scaled) * agate_mlp
del ashift_mlp, ascale_mlp, agate_mlp
return vx, ax
class LTXAVModel(LTXVModel):
"""LTXAV model for audio-video generation."""
def __init__(
self,
in_channels=128,
audio_in_channels=128,
cross_attention_dim=4096,
audio_cross_attention_dim=2048,
attention_head_dim=128,
audio_attention_head_dim=64,
num_attention_heads=32,
audio_num_attention_heads=32,
caption_channels=3840,
num_layers=48,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
audio_positional_embedding_max_pos=[20],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
# Store audio-specific parameters
self.audio_in_channels = audio_in_channels
self.audio_cross_attention_dim = audio_cross_attention_dim
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
self.audio_out_channels = audio_in_channels
# Audio-specific constants
self.num_audio_channels = 8
self.audio_frequency_bins = 16
self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXAV-specific components."""
# Audio-specific projections
self.audio_patchify_proj = self.operations.Linear(
self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device
)
# Audio-specific AdaLN
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle(
self.inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=num_scale_shift_values,
dtype=dtype,
device=device,
operations=self.operations,
)
self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
use_additional_conditions=False,
embedding_coefficient=1,
dtype=dtype,
device=device,
operations=self.operations,
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXAV."""
self.transformer_blocks = nn.ModuleList(
[
BasicAVTransformerBlock(
v_dim=self.inner_dim,
a_dim=self.audio_inner_dim,
v_heads=self.num_attention_heads,
a_heads=self.audio_num_attention_heads,
vd_head=self.attention_head_dim,
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
for _ in range(self.num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXAV."""
# Video output components
super()._init_output_components(device, dtype)
# Audio output components
self.audio_scale_shift_table = nn.Parameter(
torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device)
)
self.audio_norm_out = self.operations.LayerNorm(
self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.audio_proj_out = self.operations.Linear(
self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device
)
self.a_patchifier = AudioPatchifier(1, start_end=True)
def separate_audio_and_video_latents(self, x, audio_length):
"""Separate audio and video latents from combined input."""
# vx = x[:, : self.in_channels]
# ax = x[:, self.in_channels :]
#
# ax = ax.reshape(ax.shape[0], -1)
# ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins]
#
# ax = ax.reshape(
# ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins
# )
vx = x[0]
ax = x[1] if len(x) > 1 else torch.zeros(
(vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins),
device=vx.device, dtype=vx.dtype
)
return vx, ax
def recombine_audio_and_video_latents(self, vx, ax, target_shape=None):
if ax.numel() == 0:
return vx
else:
return [vx, ax]
"""Recombine audio and video latents for output."""
# if ax.device != vx.device or ax.dtype != vx.dtype:
# logging.warning("Audio and video latents are on different devices or dtypes.")
# ax = ax.to(device=vx.device, dtype=vx.dtype)
# logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}")
#
# ax = ax.reshape(ax.shape[0], -1)
# # pad to f x h x w of the video latents
# divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3]
# if target_shape is None:
# repetitions = math.ceil(ax.shape[-1] / divisor)
# else:
# repetitions = target_shape[1] - vx.shape[1]
# padded_len = repetitions * divisor
# ax = F.pad(ax, (0, padded_len - ax.shape[-1]))
# ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1])
# return torch.cat([vx, ax], dim=1)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXAV - separate audio and video, then patchify."""
audio_length = kwargs.get("audio_length", 0)
# Separate audio and video latents
vx, ax = self.separate_audio_and_video_latents(x, audio_length)
[vx, v_pixel_coords, additional_args] = super()._process_input(
vx, keyframe_idxs, denoise_mask, **kwargs
)
ax, a_latent_coords = self.a_patchifier.patchify(ax)
ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)})
return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape")
v_patches_per_frame = None
if orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
timestep_flat = timestep_scaled.flatten()
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]
a_timestep, a_embedded_timestep = self.audio_adaln_single(
a_timestep_flat,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
else:
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
return [v_timestep, a_timestep, cross_av_timestep_ss], [
v_embedded_timestep,
a_embedded_timestep,
]
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
return [v_context, a_context], attention_mask
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
v_pixel_coords = pixel_coords[0]
v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype)
a_latent_coords = pixel_coords[1]
a_pe = self._precompute_freqs_cis(
a_latent_coords,
dim=self.audio_inner_dim,
out_dtype=x_dtype,
max_pos=self.audio_positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.audio_num_attention_heads,
)
# calculate positional embeddings for the middle of the token duration, to use in av cross attention layers.
max_pos = max(
self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0]
)
v_pixel_coords = v_pixel_coords.to(torch.float32)
v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate)
av_cross_video_freq_cis = self._precompute_freqs_cis(
v_pixel_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
av_cross_audio_freq_cis = self._precompute_freqs_cis(
a_latent_coords[:, 0:1, :],
dim=self.audio_cross_attention_dim,
out_dtype=x_dtype,
max_pos=[max_pos],
use_middle_indices_grid=True,
num_attention_heads=self.audio_num_attention_heads,
)
return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)]
def _process_transformer_blocks(
self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs
):
vx = x[0]
ax = x[1]
v_context = context[0]
a_context = context[1]
v_timestep = timestep[0]
a_timestep = timestep[1]
v_pe, av_cross_video_freq_cis = pe[0]
a_pe, av_cross_audio_freq_cis = pe[1]
(
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(
args["img"],
v_context=args["v_context"],
a_context=args["a_context"],
attention_mask=args["attention_mask"],
v_timestep=args["v_timestep"],
a_timestep=args["a_timestep"],
v_pe=args["v_pe"],
a_pe=args["a_pe"],
v_cross_pe=args["v_cross_pe"],
a_cross_pe=args["a_cross_pe"],
v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"],
a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"],
v_cross_gate_timestep=args["v_cross_gate_timestep"],
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
)
return out
out = blocks_replace[("double_block", i)](
{
"img": (vx, ax),
"v_context": v_context,
"a_context": a_context,
"attention_mask": attention_mask,
"v_timestep": v_timestep,
"a_timestep": a_timestep,
"v_pe": v_pe,
"a_pe": a_pe,
"v_cross_pe": av_cross_video_freq_cis,
"a_cross_pe": av_cross_audio_freq_cis,
"v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep,
"a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep,
"v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep,
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
},
{"original_block": block_wrap},
)
vx, ax = out["img"]
else:
vx, ax = block(
(vx, ax),
v_context=v_context,
a_context=a_context,
attention_mask=attention_mask,
v_timestep=v_timestep,
a_timestep=a_timestep,
v_pe=v_pe,
a_pe=a_pe,
v_cross_pe=av_cross_video_freq_cis,
a_cross_pe=av_cross_audio_freq_cis,
v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep,
a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep,
v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep,
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
)
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
vx = x[0]
ax = x[1]
v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1]
# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
# Process audio output
a_scale_shift_values = (
self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype)
+ a_embedded_timestep[:, :, None]
)
a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1]
ax = self.audio_norm_out(ax)
ax = ax * (1 + a_scale) + a_shift
ax = self.audio_proj_out(ax)
# Unpatchify audio
ax = self.a_patchifier.unpatchify(
ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins
)
# Recombine audio and video
original_shape = kwargs.get("av_orig_shape")
return self.recombine_audio_and_video_latents(vx, ax, original_shape)
def forward(
self,
x,
timestep,
context,
attention_mask=None,
frame_rate=25,
transformer_options={},
keyframe_idxs=None,
**kwargs,
):
"""
Forward pass for LTXAV model.
Args:
x: Combined audio-video input tensor
timestep: Tuple of (video_timestep, audio_timestep) or single timestep
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments including audio_length
Returns:
Combined audio-video output tensor
"""
# Handle timestep format
if isinstance(timestep, (tuple, list)) and len(timestep) == 2:
v_timestep, a_timestep = timestep
kwargs["a_timestep"] = a_timestep
timestep = v_timestep
else:
kwargs["a_timestep"] = timestep
# Call parent forward method
return super().forward(
x,
timestep,
context,
attention_mask,
frame_rate,
transformer_options,
keyframe_idxs,
**kwargs,
)

View File

@ -0,0 +1,305 @@
import math
from typing import Optional
import comfy.ldm.common_dit
import torch
from comfy.ldm.lightricks.model import (
CrossAttention,
FeedForward,
generate_freq_grid_np,
interleaved_freqs_cis,
split_freqs_cis,
)
from torch import nn
class BasicTransformerBlock1D(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers.
qk_norm (`str`, *optional*, defaults to None):
Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`.
ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer.
attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer.
use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE).
ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer.
"""
def __init__(
self,
dim,
n_heads,
d_head,
context_dim=None,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
context_dim=None,
dtype=dtype,
device=device,
operations=operations,
)
# 3. Feed-forward
self.ff = FeedForward(
dim,
dim_out=dim,
glu=True,
dtype=dtype,
device=device,
operations=operations,
)
def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Normalization Before Self-Attention
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
norm_hidden_states = norm_hidden_states.squeeze(1)
# 2. Self-Attention
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 3. Normalization before Feed-Forward
norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
# 4. Feed-forward
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class Embeddings1DConnector(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=128,
num_attention_heads=30,
num_layers=2,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
dtype=None,
device=None,
operations=None,
split_rope=False,
double_precision_rope=False,
**kwargs,
):
super().__init__()
self.dtype = dtype
self.out_channels = in_channels
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos
self.split_rope = split_rope
self.double_precision_rope = double_precision_rope
self.transformer_1d_blocks = nn.ModuleList(
[
BasicTransformerBlock1D(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
dtype=dtype,
device=device,
operations=operations,
)
for _ in range(num_layers)
]
)
inner_dim = num_attention_heads * attention_head_dim
self.num_learnable_registers = num_learnable_registers
if self.num_learnable_registers:
self.learnable_registers = nn.Parameter(
torch.rand(
self.num_learnable_registers, inner_dim, dtype=dtype, device=device
)
* 2.0
- 1.0
)
def get_fractional_positions(self, indices_grid):
fractional_positions = torch.stack(
[
indices_grid[:, i] / self.positional_embedding_max_pos[i]
for i in range(1)
],
dim=-1,
)
return fractional_positions
def precompute_freqs(self, indices_grid, spacing):
source_dtype = indices_grid.dtype
dtype = (
torch.float32
if source_dtype in (torch.bfloat16, torch.float16)
else source_dtype
)
fractional_positions = self.get_fractional_positions(indices_grid)
indices = (
generate_freq_grid_np(
self.positional_embedding_theta,
indices_grid.shape[1],
self.inner_dim,
)
if self.double_precision_rope
else self.generate_freq_grid(spacing, dtype, fractional_positions.device)
).to(device=fractional_positions.device)
if spacing == "exp_2":
freqs = (
(indices * fractional_positions.unsqueeze(-1))
.transpose(-1, -2)
.flatten(2)
)
else:
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
return freqs
def generate_freq_grid(self, spacing, dtype, device):
dim = self.inner_dim
theta = self.positional_embedding_theta
n_pos_dims = 1
n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6
start = 1
end = theta
if spacing == "exp":
indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem))
indices = indices.to(dtype=dtype, device=device)
elif spacing == "exp_2":
indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim)
indices = indices.to(dtype=dtype)
elif spacing == "linear":
indices = torch.linspace(
start, end, dim // n_elem, device=device, dtype=dtype
)
elif spacing == "sqrt":
indices = torch.linspace(
start**2, end**2, dim // n_elem, device=device, dtype=dtype
).sqrt()
indices = indices * math.pi / 2
return indices
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
dim = self.inner_dim
n_elem = 2 # 2 because of cos and sin
freqs = self.precompute_freqs(indices_grid, spacing)
if self.split_rope:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(
freqs, pad_size, self.num_attention_heads
)
else:
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# 1. Input
if self.num_learnable_registers:
num_registers_duplications = math.ceil(
max(1024, hidden_states.shape[1]) / self.num_learnable_registers
)
learnable_registers = torch.tile(
self.learnable_registers.to(hidden_states), (num_registers_duplications, 1)
)
hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1)
if attention_mask is not None:
attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device)
indices_grid = torch.arange(
hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device
)
indices_grid = indices_grid[None, None, :]
freqs_cis = self.precompute_freqs_cis(indices_grid)
# 2. Blocks
for block_idx, block in enumerate(self.transformer_1d_blocks):
hidden_states = block(
hidden_states, attention_mask=attention_mask, pe=freqs_cis
)
# 3. Output
# if self.output_scale is not None:
# hidden_states = hidden_states / self.output_scale
hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states)
return hidden_states, attention_mask

View File

@ -0,0 +1,292 @@
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
def _rational_for_scale(scale: float) -> Tuple[int, int]:
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
if float(scale) not in mapping:
raise ValueError(
f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}"
)
return mapping[float(scale)]
class PixelShuffleND(nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
self.dims = dims
self.upscale_factors = upscale_factors
def forward(self, x):
if self.dims == 3:
return rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
p3=self.upscale_factors[2],
)
elif self.dims == 2:
return rearrange(
x,
"b (c p1 p2) h w -> b c (h p1) (w p2)",
p1=self.upscale_factors[0],
p2=self.upscale_factors[1],
)
elif self.dims == 1:
return rearrange(
x,
"b (c p1) f h w -> b c (f p1) h w",
p1=self.upscale_factors[0],
)
class BlurDownsample(nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int):
super().__init__()
assert dims in (2, 3)
assert stride >= 1 and isinstance(stride, int)
self.dims = dims
self.stride = stride
# 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized
k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (5,5)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
def _apply_2d(x2d: torch.Tensor) -> torch.Tensor:
# x2d: (B, C, H, W)
B, C, H, W = x2d.shape
weight = self.kernel.expand(C, 1, 5, 5) # depthwise
x2d = F.conv2d(
x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C
)
return x2d
if self.dims == 2:
return _apply_2d(x)
else:
# dims == 3: apply per-frame on H,W
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = _apply_2d(x)
h2, w2 = x.shape[-2:]
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
return x
class SpatialRationalResampler(nn.Module):
"""
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
downsample by 'den' using fixed blur + stride. Operates on H,W only.
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
"""
def __init__(self, mid_channels: int, scale: float):
super().__init__()
self.scale = float(scale)
self.num, self.den = _rational_for_scale(self.scale)
self.conv = nn.Conv2d(
mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1
)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = x.shape
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
return x
class ResBlock(nn.Module):
def __init__(
self, channels: int, mid_channels: Optional[int] = None, dims: int = 3
):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = nn.GroupNorm(32, channels)
self.activation = nn.SiLU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = self.conv1(x)
x = self.norm1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.norm2(x)
x = self.activation(x + residual)
return x
class LatentUpsampler(nn.Module):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`): Number of channels in the input latent
mid_channels (`int`): Number of channels in the middle layers
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`): Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`): Whether to spatially upsample the latent
temporal_upsample (`bool`): Whether to temporally upsample the latent
"""
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 512,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
spatial_scale: float = 2.0,
rational_resampler: bool = False,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
self.spatial_scale = float(spatial_scale)
self.rational_resampler = rational_resampler
Conv = nn.Conv2d if dims == 2 else nn.Conv3d
self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = nn.GroupNorm(32, mid_channels)
self.initial_activation = nn.SiLU()
self.res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
if spatial_upsample and temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_resampler:
self.upsampler = SpatialRationalResampler(
mid_channels=mid_channels, scale=self.spatial_scale
)
else:
self.upsampler = nn.Sequential(
nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = nn.Sequential(
nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError(
"Either spatial_upsample or temporal_upsample must be True"
)
self.post_upsample_res_blocks = nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, latent: torch.Tensor) -> torch.Tensor:
b, c, f, h, w = latent.shape
if self.dims == 2:
x = rearrange(latent, "b c f h w -> (b f) c h w")
x = self.initial_conv(x)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
x = self.upsampler(x)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
else:
x = self.initial_conv(latent)
x = self.initial_norm(x)
x = self.initial_activation(x)
for block in self.res_blocks:
x = block(x)
if self.temporal_upsample:
x = self.upsampler(x)
x = x[:, :, 1:, :, :]
else:
if isinstance(self.upsampler, SpatialRationalResampler):
x = self.upsampler(x)
else:
x = rearrange(x, "b c f h w -> (b f) c h w")
x = self.upsampler(x)
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
for block in self.post_upsample_res_blocks:
x = block(x)
x = self.final_conv(x)
return x
@classmethod
def from_config(cls, config):
return cls(
in_channels=config.get("in_channels", 4),
mid_channels=config.get("mid_channels", 128),
num_blocks_per_stage=config.get("num_blocks_per_stage", 4),
dims=config.get("dims", 2),
spatial_upsample=config.get("spatial_upsample", True),
temporal_upsample=config.get("temporal_upsample", False),
spatial_scale=config.get("spatial_scale", 2.0),
rational_resampler=config.get("rational_resampler", False),
)
def config(self):
return {
"_class_name": "LatentUpsampler",
"in_channels": self.in_channels,
"mid_channels": self.mid_channels,
"num_blocks_per_stage": self.num_blocks_per_stage,
"dims": self.dims,
"spatial_upsample": self.spatial_upsample,
"temporal_upsample": self.temporal_upsample,
"spatial_scale": self.spatial_scale,
"rational_resampler": self.rational_resampler,
}

View File

@ -1,13 +1,47 @@
from abc import ABC, abstractmethod
from enum import Enum
import functools
import math
from typing import Dict, Optional, Tuple
from einops import rearrange
import numpy as np
import torch
from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
from comfy.ldm.flux.math import apply_rope1
def _log_base(x, base):
return np.log(x) / np.log(base)
class LTXRopeType(str, Enum):
INTERLEAVED = "interleaved"
SPLIT = "split"
KEY = "rope_type"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.INTERLEAVED
return cls(kwargs.get(cls.KEY, default))
class LTXFrequenciesPrecision(str, Enum):
FLOAT32 = "float32"
FLOAT64 = "float64"
KEY = "frequencies_precision"
@classmethod
def from_dict(cls, kwargs, default=None):
if default is None:
default = cls.FLOAT32
return cls(kwargs.get(cls.KEY, default))
def get_timestep_embedding(
timesteps: torch.Tensor,
@ -39,9 +73,7 @@ def get_timestep_embedding(
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module):
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None, device=None, operations=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
self.linear_2 = operations.Linear(
time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
)
if post_act_fn is None:
self.post_act = None
@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
def __init__(
self,
embedding_dim,
size_emb_dim,
use_additional_conditions: bool = False,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
self.timestep_embedder = TimestepEmbedding(
in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module):
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
def __init__(
self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
embedding_dim,
size_emb_dim=embedding_dim // 3,
use_additional_conditions=use_additional_conditions,
dtype=dtype,
device=device,
operations=operations,
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module):
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module):
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
def __init__(
self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
self.linear_1 = operations.Linear(
in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
self.linear_2 = operations.Linear(
in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
@ -222,23 +282,68 @@ class GELU_approx(nn.Module):
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis):
cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
return (
apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
if split_pe else
apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
)
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
def apply_split_rotary_emb(input_tensor, cos, sin):
needs_reshape = False
if input_tensor.ndim != 4 and cos.ndim == 4:
B, H, T, _ = cos.shape
input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
needs_reshape = True
split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
first_half_input = split_input[..., :1, :]
second_half_input = split_input[..., 1:, :]
output = split_input * cos.unsqueeze(-2)
first_half_output = output[..., :1, :]
second_half_output = output[..., 1:, :]
first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
output = rearrange(output, "... d r -> ... (d r)")
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
def __init__(
self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
@ -254,9 +359,11 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
@ -266,8 +373,8 @@ class CrossAttention(nn.Module):
k = self.k_norm(k)
if pe is not None:
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe if k_pe is None else k_pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@ -277,14 +384,34 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
context_dim=None,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
attn_precision=self.attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module):
return x
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
fractional_positions = torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
[indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
axis=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32
device = indices_grid.device
@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
pow_indices = np.power(
theta,
np.linspace(
_log_base(start, theta),
_log_base(end, theta),
inner_dim // n_elem,
dtype=np.float64,
),
)
return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
theta = positional_embedding_theta
start = 1
end = theta
n_elem = 2 * positional_embedding_max_pos_count
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
inner_dim // n_elem,
device=device,
dtype=torch.float32,
)
)
indices = indices.to(dtype=torch.float32)
indices = indices * math.pi / 2
return indices
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
if use_middle_indices_grid:
assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2)
indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
indices_grid = (indices_grid_start + indices_grid_end) / 2.0
elif len(indices_grid.shape) == 4:
indices_grid = indices_grid[..., 0]
# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
indices = indices.to(device=fractional_positions.device)
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
return freqs
# Pad if dim is not divisible by 6
if dim % 6 != 0:
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
def interleaved_freqs_cis(freqs, pad_size):
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : pad_size])
sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq, sin_freq
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
def split_freqs_cis(freqs, pad_size, num_attention_heads):
cos_freq = freqs.cos()
sin_freq = freqs.sin()
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
return freqs_cis
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape freqs to be compatible with multi-head attention
B , T, half_HD = cos_freq.shape
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
caption_channels=4096,
num_layers=28,
cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
return cos_freq, sin_freq
class LTXBaseModel(torch.nn.Module, ABC):
"""
Abstract base class for LTX models (Lightricks Transformer models).
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
This class defines the common interface and shared functionality for all LTX models,
including LTXV (video) and LTXAV (audio-video) variants.
"""
def __init__(
self,
in_channels: int,
cross_attention_dim: int,
attention_head_dim: int,
num_attention_heads: int,
caption_channels: int,
num_layers: int,
positional_embedding_theta: float = 10000.0,
positional_embedding_max_pos: list = [20, 2048, 2048],
causal_temporal_positioning: bool = False,
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.generator = None
self.vae_scale_factors = vae_scale_factors
self.use_middle_indices_grid = use_middle_indices_grid
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.in_channels = in_channels
self.cross_attention_dim = cross_attention_dim
self.attention_head_dim = attention_head_dim
self.num_attention_heads = num_attention_heads
self.caption_channels = caption_channels
self.num_layers = num_layers
self.positional_embedding_theta = positional_embedding_theta
self.positional_embedding_max_pos = positional_embedding_max_pos
self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
self.freq_grid_generator = (
generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
else generate_freq_grid_pytorch
)
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
self.out_channels = in_channels
# Initialize common components
self._init_common_components(device, dtype)
# Initialize model-specific components
self._init_model_components(device, dtype, **kwargs)
# Initialize transformer blocks
self._init_transformer_blocks(device, dtype, **kwargs)
# Initialize output components
self._init_output_components(device, dtype)
def _init_common_components(self, device, dtype):
"""Initialize components common to all LTX models
- patchify_proj: Linear projection for patchifying input
- adaln_single: AdaLN layer for timestep embedding
- caption_projection: Linear projection for caption embedding
"""
self.patchify_proj = self.operations.Linear(
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize model-specific components. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _init_output_components(self, device, dtype):
"""Initialize output components. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input data. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
"""Process transformer blocks. Must be implemented by subclasses."""
pass
@abstractmethod
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output data. Must be implemented by subclasses."""
pass
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings."""
grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
self,
indices_grid,
dim,
out_dtype,
theta=10000.0,
max_pos=[20, 2048, 2048],
use_middle_indices_grid=False,
num_attention_heads=32,
):
split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)
if split_mode:
expected_freqs = dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
else:
# 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only
n_elem = 2 * indices_grid.shape[1]
cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
"""Prepare positional embeddings."""
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
pe = self._precompute_freqs_cis(
fractional_coords,
dim=self.inner_dim,
out_dtype=x_dtype,
max_pos=self.positional_embedding_max_pos,
use_middle_indices_grid=self.use_middle_indices_grid,
num_attention_heads=self.num_attention_heads,
)
return pe
def _prepare_attention_mask(self, attention_mask, x_dtype):
"""Prepare attention mask."""
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
) * torch.finfo(x_dtype).max
return attention_mask
def forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(
comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
),
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)
def _forward(
self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
"""
Internal forward pass for LTX models.
Args:
x: Input tensor
timestep: Timestep tensor
context: Context tensor (e.g., text embeddings)
attention_mask: Attention mask tensor
frame_rate: Frame rate for temporal processing
transformer_options: Additional options for transformer blocks
keyframe_idxs: Keyframe indices for temporal processing
**kwargs: Additional keyword arguments
Returns:
Processed output tensor
"""
if isinstance(x, list):
input_dtype = x[0].dtype
batch_size = x[0].shape[0]
else:
input_dtype = x.dtype
batch_size = x.shape[0]
# Process input
merged_args = {**transformer_options, **kwargs}
x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)
# Process transformer blocks
x = self._process_transformer_blocks(
x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
)
# Process output
x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
return x
class LTXVModel(LTXBaseModel):
"""LTXV model for video generation."""
def __init__(
self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__(
in_channels=in_channels,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
caption_channels=caption_channels,
num_layers=num_layers,
positional_embedding_theta=positional_embedding_theta,
positional_embedding_max_pos=positional_embedding_max_pos,
causal_temporal_positioning=causal_temporal_positioning,
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
dtype=dtype,
device=device,
operations=operations,
**kwargs,
)
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
"""Initialize transformer blocks for LTXV."""
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
for d in range(num_layers)
for _ in range(self.num_layers)
]
)
def _init_output_components(self, device, dtype):
"""Initialize output components for LTXV."""
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs)
def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
orig_shape = list(x.shape)
self.norm_out = self.operations.LayerNorm(
self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
)
self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1, start_end=True)
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
"""Process input for LTXV."""
additional_args = {"orig_shape": list(x.shape)}
x, latent_coords = self.patchifier.patchify(x)
pixel_coords = latent_to_pixel_coords(
latent_coords=latent_coords,
@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module):
causal_fix=self.causal_temporal_positioning,
)
grid_mask = None
if keyframe_idxs is not None:
pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
additional_args.update({ "orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
additional_args.update({"grid_mask": grid_mask})
x = x[:, grid_mask, :]
pixel_coords = pixel_coords[:, :, grid_mask, ...]
fractional_coords = pixel_coords.to(torch.float32)
fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
x = self.patchify_proj(x)
timestep = timestep * 1000.0
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
return x, pixel_coords, additional_args
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module):
transformer_options=transformer_options,
)
# 3. Output
return x
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
"""Process output for LTXV."""
# Apply scale-shift modulation
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = torch.addcmul(x, x, scale).add_(shift)
x = x * (1 + scale) + shift
x = self.proj_out(x)
if keyframe_idxs is not None:
grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
full_x[:, grid_mask, :] = x
x = full_x
# Unpatchify to restore original dimensions
orig_shape = kwargs["orig_shape"]
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],

View File

@ -21,20 +21,23 @@ def latent_to_pixel_coords(
Returns:
Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
"""
shape = [1] * latent_coords.ndim
shape[1] = -1
pixel_coords = (
latent_coords
* torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
* torch.tensor(scale_factors, device=latent_coords.device).view(*shape)
)
if causal_fix:
# Fix temporal scale for first frame to 1 due to causality
pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
return pixel_coords
class Patchifier(ABC):
def __init__(self, patch_size: int):
def __init__(self, patch_size: int, start_end: bool=False):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
self.start_end = start_end
@abstractmethod
def patchify(
@ -71,11 +74,23 @@ class Patchifier(ABC):
torch.arange(0, latent_width, self._patch_size[2], device=device),
indexing="ij",
)
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_coords = rearrange(
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
latent_sample_coords_end = latent_sample_coords_start + delta
latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_sample_coords_start = rearrange(
latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
)
if self.start_end:
latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
latent_sample_coords_end = rearrange(
latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
)
latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
else:
latent_coords = latent_sample_coords_start
return latent_coords
@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
q=self._patch_size[2],
)
return latents
class AudioPatchifier(Patchifier):
def __init__(self, patch_size: int,
sample_rate=16000,
hop_length=160,
audio_latent_downsample_factor=4,
is_causal=True,
start_end=False,
shift = 0
):
super().__init__(patch_size, start_end=start_end)
self.hop_length = hop_length
self.sample_rate = sample_rate
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.is_causal = is_causal
self.shift = shift
def copy_with_shift(self, shift):
return AudioPatchifier(
self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
self.is_causal, self.start_end, shift
)
def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device):
audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
if self.is_causal:
audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
return audio_mel_frame * self.hop_length / self.sample_rate
def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# audio_latents: (batch, channels, time, freq)
b, _, t, _ = audio_latents.shape
audio_latents = rearrange(
audio_latents,
"b c t f -> b t (c f)",
)
audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
if self.start_end:
audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
else:
audio_latents_timings = audio_latents_start_timings
return audio_latents, audio_latents_timings
def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
# audio_latents: (batch, time, freq * channels)
audio_latents = rearrange(
audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
)
return audio_latents

View File

@ -0,0 +1,286 @@
import json
from dataclasses import dataclass
import math
import torch
import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
LATENT_DOWNSAMPLE_FACTOR = 4
@dataclass(frozen=True)
class AudioVAEComponentConfig:
"""Container for model component configuration extracted from metadata."""
autoencoder: dict
vocoder: dict
@classmethod
def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
raw_config = metadata["config"]
if isinstance(raw_config, str):
parsed_config = json.loads(raw_config)
else:
parsed_config = raw_config
audio_config = parsed_config.get("audio_vae")
vocoder_config = parsed_config.get("vocoder")
assert audio_config is not None, "Audio VAE config is required for audio VAE"
assert vocoder_config is not None, "Vocoder config is required for audio VAE"
return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout."""
def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module):
self.patchifier = patchfier
self.statistics = statistics_processor
def normalize(self, latents: torch.Tensor) -> torch.Tensor:
channels = latents.shape[1]
freq = latents.shape[3]
patched, _ = self.patchifier.patchify(latents)
normalized = self.statistics.normalize(patched)
return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
channels = latents.shape[1]
freq = latents.shape[3]
patched, _ = self.patchifier.patchify(latents)
denormalized = self.statistics.un_normalize(patched)
return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq)
class AudioPreprocessor:
"""Prepares raw waveforms for the autoencoder by matching training conditions."""
def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int):
self.target_sample_rate = target_sample_rate
self.mel_bins = mel_bins
self.mel_hop_length = mel_hop_length
self.n_fft = n_fft
def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
if source_rate == self.target_sample_rate:
return waveform
return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate)
@staticmethod
def normalize_amplitude(
waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5
) -> torch.Tensor:
waveform = waveform - waveform.mean(dim=2, keepdim=True)
peak = torch.max(torch.abs(waveform)) + eps
scale = peak.clamp(max=max_amplitude) / peak
return waveform * scale
def waveform_to_mel(
self, waveform: torch.Tensor, waveform_sample_rate: int, device
) -> torch.Tensor:
waveform = self.resample(waveform, waveform_sample_rate)
waveform = self.normalize_amplitude(waveform)
mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=self.target_sample_rate,
n_fft=self.n_fft,
win_length=self.n_fft,
hop_length=self.mel_hop_length,
f_min=0.0,
f_max=self.target_sample_rate / 2.0,
n_mels=self.mel_bins,
window_fn=torch.hann_window,
center=True,
pad_mode="reflect",
power=1.0,
mel_scale="slaney",
norm="slaney",
).to(device)
mel = mel_transform(waveform)
mel = torch.log(torch.clamp(mel, min=1e-5))
return mel.permute(0, 1, 3, 2).contiguous()
class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, state_dict: dict, metadata: dict):
super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer(
AudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=autoencoder_config["sampling_rate"],
hop_length=autoencoder_config["mel_hop_length"],
is_causal=autoencoder_config["is_causal"],
),
self.autoencoder.per_channel_statistics,
)
self.preprocessor = AudioPreprocessor(
target_sample_rate=autoencoder_config["sampling_rate"],
mel_bins=autoencoder_config["mel_bins"],
mel_hop_length=autoencoder_config["mel_hop_length"],
n_fft=autoencoder_config["n_fft"],
)
self.device_manager = ModelDeviceManager(self)
def encode(self, audio: dict) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors."""
waveform = audio["waveform"]
waveform_sample_rate = audio["sample_rate"]
input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels:
raise ValueError(
f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}"
)
mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device
)
latents = self.autoencoder.encode(mel_spec)
posterior = DiagonalGaussianDistribution(latents)
latent_mode = posterior.mode()
normalized = self.normalizer.normalize(latent_mode)
return normalized.to(input_device)
def decode(self, latents: torch.Tensor) -> torch.Tensor:
"""Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform)
def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape
target_length = time * LATENT_DOWNSAMPLE_FACTOR
if self.autoencoder.causality_axis != CausalityAxis.NONE:
target_length -= LATENT_DOWNSAMPLE_FACTOR - 1
return (
batch,
self.autoencoder.decoder.out_ch,
target_length,
self.autoencoder.mel_bins,
)
def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int:
return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second)
def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor:
audio_channels = self.autoencoder.decoder.out_ch
vocoder_input = mel_spec.transpose(2, 3)
if audio_channels == 1:
vocoder_input = vocoder_input.squeeze(1)
elif audio_channels != 2:
raise ValueError(f"Unsupported audio_channels: {audio_channels}")
return self.vocoder(vocoder_input)
@property
def sample_rate(self) -> int:
return int(self.autoencoder.sampling_rate)
@property
def mel_hop_length(self) -> int:
return int(self.autoencoder.mel_hop_length)
@property
def mel_bins(self) -> int:
return int(self.autoencoder.mel_bins)
@property
def latent_channels(self) -> int:
return int(self.autoencoder.decoder.z_channels)
@property
def latent_frequency_bins(self) -> int:
return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR)
@property
def latents_per_second(self) -> float:
return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR
@property
def output_sample_rate(self) -> int:
output_rate = getattr(self.vocoder, "output_sample_rate", None)
if output_rate is not None:
return int(output_rate)
upsample_factor = getattr(self.vocoder, "upsample_factor", None)
if upsample_factor is None:
raise AttributeError(
"Vocoder is missing upsample_factor; cannot infer output sample rate"
)
return int(self.sample_rate * upsample_factor / self.mel_hop_length)
def memory_required(self, input_shape):
return self.device_manager.patcher.model_size()

View File

@ -0,0 +1,909 @@
from __future__ import annotations
import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional
from enum import Enum
from .pixel_norm import PixelNorm
import comfy.ops
import logging
ops = comfy.ops.disable_weight_init
class StringConvertibleEnum(Enum):
"""
Base enum class that provides string-to-enum conversion functionality.
This mixin adds a str_to_enum() class method that handles conversion from
strings, None, or existing enum instances with case-insensitive matching.
"""
@classmethod
def str_to_enum(cls, value):
"""
Convert a string, enum instance, or None to the appropriate enum member.
Args:
value: Can be an enum instance of this class, a string, or None
Returns:
Enum member of this class
Raises:
ValueError: If the value cannot be converted to a valid enum member
"""
# Already an enum instance of this class
if isinstance(value, cls):
return value
# None maps to NONE member if it exists
if value is None:
if hasattr(cls, "NONE"):
return cls.NONE
raise ValueError(f"{cls.__name__} does not have a NONE member to map None to")
# String conversion (case-insensitive)
if isinstance(value, str):
value_lower = value.lower()
# Try to match against enum values
for member in cls:
# Handle members with None values
if member.value is None:
if value_lower == "none":
return member
# Handle members with string values
elif isinstance(member.value, str) and member.value.lower() == value_lower:
return member
# Build helpful error message with valid values
valid_values = []
for member in cls:
if member.value is None:
valid_values.append("none")
elif isinstance(member.value, str):
valid_values.append(member.value)
raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}")
raise ValueError(
f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. "
f"Expected string, None, or {cls.__name__} instance."
)
class AttentionType(StringConvertibleEnum):
"""Enum for specifying the attention mechanism type."""
VANILLA = "vanilla"
LINEAR = "linear"
NONE = "none"
class CausalityAxis(StringConvertibleEnum):
"""Enum for specifying the causality axis in causal convolutions."""
NONE = None
WIDTH = "width"
HEIGHT = "height"
WIDTH_COMPATIBILITY = "width-compatibility"
def Normalize(in_channels, *, num_groups=32, normtype="group"):
if normtype == "group":
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
elif normtype == "pixel":
return PixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {normtype}")
class CausalConv2d(nn.Module):
"""
A causal 2D convolution.
This layer ensures that the output at time `t` only depends on inputs
at time `t` and earlier. It achieves this by applying asymmetric padding
to the time dimension (width) before the convolution.
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias=True,
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
):
super().__init__()
self.causality_axis = causality_axis
# Ensure kernel_size and dilation are tuples
kernel_size = nn.modules.utils._pair(kernel_size)
dilation = nn.modules.utils._pair(dilation)
# Calculate padding dimensions
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
# The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom)
match self.causality_axis:
case CausalityAxis.NONE:
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY:
self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
case CausalityAxis.HEIGHT:
self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
case _:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
# The internal convolution layer uses no padding, as we handle it manually
self.conv = ops.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x):
# Apply causal padding before convolution
x = F.pad(x, self.padding)
return self.conv(x)
def make_conv2d(
in_channels,
out_channels,
kernel_size,
stride=1,
padding=None,
dilation=1,
groups=1,
bias=True,
causality_axis: Optional[CausalityAxis] = None,
):
"""
Create a 2D convolution layer that can be either causal or non-causal.
Args:
in_channels: Number of input channels
out_channels: Number of output channels
kernel_size: Size of the convolution kernel
stride: Convolution stride
padding: Padding (if None, will be calculated based on causal flag)
dilation: Dilation rate
groups: Number of groups for grouped convolution
bias: Whether to use bias
causality_axis: Dimension along which to apply causality.
Returns:
Either a regular Conv2d or CausalConv2d layer
"""
if causality_axis is not None:
# For causal convolution, padding is handled internally by CausalConv2d
return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis)
else:
# For non-causal convolution, use symmetric padding if not specified
if padding is None:
if isinstance(kernel_size, int):
padding = kernel_size // 2
else:
padding = tuple(k // 2 for k in kernel_size)
return ops.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT):
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
# Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n.
# For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2].
# The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2],
# So the output elements rely on the following windows:
# 0: [-,-,0]
# 1: [-,0,0]
# 2: [0,0,1]
# 3: [0,1,1]
# 4: [1,1,2]
# 5: [1,2,2]
# Notice that the first and second elements in the output rely only on the first element in the input,
# while all other elements rely on two elements in the input.
# So we can drop the first element to undo the padding (rather than the last element).
# This is a no-op for non-causal convolutions.
match self.causality_axis:
case CausalityAxis.NONE:
pass # x remains unchanged
case CausalityAxis.HEIGHT:
x = x[:, :, 1:, :]
case CausalityAxis.WIDTH:
x = x[:, :, :, 1:]
case CausalityAxis.WIDTH_COMPATIBILITY:
pass # x remains unchanged
case _:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
class Downsample(nn.Module):
"""
A downsampling layer that can use either a strided convolution
or average pooling. Supports standard and causal padding for the
convolutional mode.
"""
def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH):
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and not self.with_conv:
raise ValueError("causality is only supported when `with_conv=True`.")
if self.with_conv:
# Do time downsampling here
# no asymmetric padding in torch conv, must do it ourselves
self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
# (pad_left, pad_right, pad_top, pad_bottom)
match self.causality_axis:
case CausalityAxis.NONE:
pad = (0, 1, 0, 1)
case CausalityAxis.WIDTH:
pad = (2, 0, 0, 1)
case CausalityAxis.HEIGHT:
pad = (0, 1, 2, 0)
case CausalityAxis.WIDTH_COMPATIBILITY:
pad = (1, 0, 0, 1)
case _:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
# This branch is only taken if with_conv=False, which implies causality_axis is NONE.
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512,
norm_type="group",
causality_axis: CausalityAxis = CausalityAxis.HEIGHT,
):
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis != CausalityAxis.NONE and norm_type == "group":
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels, normtype=norm_type)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels, normtype=norm_type)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = make_conv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = self.non_linearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
h = self.norm2(h)
h = self.non_linearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels, norm_type="group"):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels, normtype=norm_type)
self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w).contiguous()
q = q.permute(0, 2, 1).contiguous() # b,hw,c
k = k.reshape(b, c, h * w).contiguous() # b,c,hw
w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w).contiguous()
w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w).contiguous()
h_ = self.proj_out(h_)
return x + h_
def make_attn(in_channels, attn_type="vanilla", norm_type="group"):
# Convert string to enum if needed
attn_type = AttentionType.str_to_enum(attn_type)
if attn_type != AttentionType.NONE:
logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels")
else:
logging.info(f"making identity attention with {in_channels} in_channels")
match attn_type:
case AttentionType.VANILLA:
return AttnBlock(in_channels, norm_type=norm_type)
case AttentionType.NONE:
return nn.Identity(in_channels)
case AttentionType.LINEAR:
raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.")
case _:
raise ValueError(f"Unknown attention type: {attn_type}")
class Encoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
attn_type="vanilla",
mid_block_add_attention=True,
norm_type="group",
causality_axis=CausalityAxis.WIDTH.value,
**ignore_kwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.z_channels = z_channels
self.double_z = double_z
self.norm_type = norm_type
# Convert string to enum if needed (for config loading)
causality_axis = CausalityAxis.str_to_enum(causality_axis)
self.attn_type = AttentionType.str_to_enum(attn_type)
# downsampling
self.conv_in = make_conv2d(
in_channels,
self.ch,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
self.non_linearity = nn.SiLU()
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
# end
self.norm_out = Normalize(block_in, normtype=self.norm_type)
self.conv_out = make_conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
causality_axis=causality_axis,
)
def forward(self, x):
"""
Forward pass through the encoder.
Args:
x: Input tensor of shape [batch, channels, time, n_mels]
Returns:
Encoded latent representation
"""
feature_maps = [self.conv_in(x)]
# Process each resolution level (from high to low resolution)
for resolution_level in range(self.num_resolutions):
# Apply residual blocks at current resolution level
for block_idx in range(self.num_res_blocks):
# Apply ResNet block with optional timestep embedding
current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None)
# Apply attention if configured for this resolution level
if len(self.down[resolution_level].attn) > 0:
current_features = self.down[resolution_level].attn[block_idx](current_features)
# Store processed features
feature_maps.append(current_features)
# Downsample spatial dimensions (except at the final resolution level)
if resolution_level != self.num_resolutions - 1:
downsampled_features = self.down[resolution_level].downsample(feature_maps[-1])
feature_maps.append(downsampled_features)
# === MIDDLE PROCESSING PHASE ===
# Take the lowest resolution features for middle processing
bottleneck_features = feature_maps[-1]
# Apply first middle ResNet block
bottleneck_features = self.mid.block_1(bottleneck_features, temb=None)
# Apply middle attention block
bottleneck_features = self.mid.attn_1(bottleneck_features)
# Apply second middle ResNet block
bottleneck_features = self.mid.block_2(bottleneck_features, temb=None)
# === OUTPUT PHASE ===
# Normalize the bottleneck features
output_features = self.norm_out(bottleneck_features)
# Apply non-linearity (SiLU activation)
output_features = self.non_linearity(output_features)
# Final convolution to produce latent representation
# [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels]
return self.conv_out(output_features)
class Decoder(nn.Module):
def __init__(
self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
attn_type="vanilla",
mid_block_add_attention=True,
norm_type="group",
causality_axis=CausalityAxis.WIDTH.value,
**ignorekwargs,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = out_ch
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
self.norm_type = norm_type
self.z_channels = z_channels
# Convert string to enum if needed (for config loading)
causality_axis = CausalityAxis.str_to_enum(causality_axis)
self.attn_type = AttentionType.str_to_enum(attn_type)
# compute block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis)
self.non_linearity = nn.SiLU()
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=causality_axis,
)
)
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in, normtype=self.norm_type)
self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
def _adjust_output_shape(self, decoded_output, target_shape):
"""
Adjust output shape to match target dimensions for variable-length audio.
This function handles the common case where decoded audio spectrograms need to be
resized to match a specific target shape.
Args:
decoded_output: Tensor of shape (batch, channels, time, frequency)
target_shape: Target shape tuple (batch, channels, time, frequency)
Returns:
Tensor adjusted to match target_shape exactly
"""
# Current output shape: (batch, channels, time, frequency)
_, _, current_time, current_freq = decoded_output.shape
_, target_channels, target_time, target_freq = target_shape
# Step 1: Crop first to avoid exceeding target dimensions
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
# Step 2: Calculate padding needed for time and frequency dimensions
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
# Step 3: Apply padding if needed
if time_padding_needed > 0 or freq_padding_needed > 0:
# PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
# For audio: pad_left/right = frequency, pad_top/bottom = time
padding = (
0,
max(freq_padding_needed, 0), # frequency padding (left, right)
0,
max(time_padding_needed, 0), # time padding (top, bottom)
)
decoded_output = F.pad(decoded_output, padding)
# Step 4: Final safety crop to ensure exact target shape
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
def get_config(self):
return {
"ch": self.ch,
"out_ch": self.out_ch,
"ch_mult": self.ch_mult,
"num_res_blocks": self.num_res_blocks,
"in_channels": self.in_channels,
"resolution": self.resolution,
"z_channels": self.z_channels,
}
def forward(self, latent_features, target_shape=None):
"""
Decode latent features back to audio spectrograms.
Args:
latent_features: Encoded latent representation of shape (batch, channels, height, width)
target_shape: Optional target output shape (batch, channels, time, frequency)
If provided, output will be cropped/padded to match this shape
Returns:
Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
"""
assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
# Transform latent features to decoder's internal feature dimension
hidden_features = self.conv_in(latent_features)
# Middle processing
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
# Upsampling
# Progressively increase spatial resolution from lowest to highest
for resolution_level in reversed(range(self.num_resolutions)):
# Apply residual blocks at current resolution level
for block_index in range(self.num_res_blocks + 1):
hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
if len(self.up[resolution_level].attn) > 0:
hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
if resolution_level != 0:
hidden_features = self.up[resolution_level].upsample(hidden_features)
# Output
if self.give_pre_end:
# Return intermediate features before final processing (for debugging/analysis)
decoded_output = hidden_features
else:
# Standard output path: normalize, activate, and convert to output channels
# Final normalization layer
hidden_features = self.norm_out(hidden_features)
# Apply SiLU (Swish) activation function
hidden_features = self.non_linearity(hidden_features)
# Final convolution to map to output channels (typically 2 for stereo audio)
decoded_output = self.conv_out(hidden_features)
# Optional tanh activation to bound output values to [-1, 1] range
if self.tanh_out:
decoded_output = torch.tanh(decoded_output)
# Adjust shape for audio data
if target_shape is not None:
decoded_output = self._adjust_output_shape(decoded_output, target_shape)
return decoded_output
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
class CausalAudioAutoencoder(nn.Module):
def __init__(self, config=None):
super().__init__()
if config is None:
config = self._guess_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
self.encoder = Encoder(**encoder_config)
self.decoder = Decoder(**decoder_config)
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
}
},
}
return config
def get_config(self):
return {
"sampling_rate": self.sampling_rate,
"mel_bins": self.mel_bins,
"mel_hop_length": self.mel_hop_length,
"n_fft": self.n_fft,
"causality_axis": self.causality_axis.value,
"is_causal": self.is_causal,
}
def encode(self, x):
return self.encoder(x)
def decode(self, x, target_shape=None):
return self.decoder(x, target_shape=target_shape)

View File

@ -0,0 +1,213 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import numpy as np
ops = comfy.ops.disable_weight_init
LRELU_SLOPE = 0.1
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.convs1 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
]
)
self.convs2 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
]
)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
class ResBlock2(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
super(ResBlock2, self).__init__()
self.convs = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
]
)
def forward(self, x):
for c in self.convs:
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c(xt)
x = xt + x
return x
class Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
"""
def __init__(self, config=None):
super(Vocoder, self).__init__()
if config is None:
config = self.get_default_config()
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
stereo = config.get("stereo", True)
resblock = config.get("resblock", "1")
self.output_sample_rate = config.get("output_sample_rate")
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
self.ups.append(
ops.ConvTranspose1d(
upsample_initial_channel // (2**i),
upsample_initial_channel // (2 ** (i + 1)),
k,
u,
padding=(k - u) // 2,
)
)
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock_class(ch, k, d))
out_channels = 2 if stereo else 1
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
def get_default_config(self):
"""Generate default configuration for the vocoder."""
config = {
"resblock_kernel_sizes": [3, 7, 11],
"upsample_rates": [6, 5, 2, 2, 2],
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"upsample_initial_channel": 1024,
"stereo": True,
"resblock": "1",
}
return config
def forward(self, x):
"""
Forward pass of the vocoder.
Args:
x: Input spectrogram tensor. Can be:
- 3D: (batch_size, channels, time_steps) for mono
- 4D: (batch_size, 2, channels, time_steps) for stereo
Returns:
Audio tensor of shape (batch_size, out_channels, audio_length)
"""
if x.dim() == 4: # stereo
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x

View File

@ -394,7 +394,8 @@ class Model(nn.Module):
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__()
if use_linear_attn: attn_type = "linear"
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
@ -548,7 +549,8 @@ class Encoder(nn.Module):
conv3d=False, time_compress=None,
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
if use_linear_attn:
attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)

View File

@ -45,7 +45,7 @@ class LitEma(nn.Module):
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
@ -54,7 +54,7 @@ class LitEma(nn.Module):
if m_param[key].requires_grad:
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
assert key not in self.m_name2s_name
def store(self, parameters):
"""

View File

@ -71,7 +71,7 @@ def count_params(model, verbose=False):
def instantiate_from_config(config):
if not "target" in config:
if "target" not in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":

View File

@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
key_map[key_lora] = to
if isinstance(model, comfy.model_base.Kandinsky5):
for k in sdk:

View File

@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1
import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB
@ -946,7 +947,7 @@ class GenmoMochi(BaseModel):
class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@ -977,6 +978,60 @@ class LTXV(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class LTXAV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
audio_denoise_mask = None
if denoise_mask is not None and "latent_shapes" in kwargs:
denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"])
if len(denoise_mask) > 1:
audio_denoise_mask = denoise_mask[1]
denoise_mask = denoise_mask[0]
if denoise_mask is not None:
out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
if audio_denoise_mask is not None:
out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask)
keyframe_idxs = kwargs.get("keyframe_idxs", None)
if keyframe_idxs is not None:
out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
latent_shapes = kwargs.get("latent_shapes", None)
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
v_timestep = timestep
a_timestep = timestep
if denoise_mask is not None:
v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
if audio_denoise_mask is not None:
a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0]
return v_timestep, a_timestep
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return latent_image
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)

View File

@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["vec_in_dim"] = None
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
@ -305,7 +307,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv"
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
dit_config["attention_head_dim"] = shape[0] // 32

View File

@ -22,7 +22,6 @@ from enum import Enum
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import importlib
import platform
import weakref
import gc
@ -349,15 +348,27 @@ try:
except:
rocm_version = (6, -1)
def aotriton_supported(gpu_arch):
path = torch.__path__[0]
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
if gpu_arch in gfx:
return True
if "{}x".format(gpu_arch[:-1]) in gfx:
return True
if "{}xx".format(gpu_arch[:-2]) in gfx:
return True
return False
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
@ -456,7 +467,7 @@ def module_size(module):
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nelement() * t.element_size()
module_mem += t.nbytes
return module_mem
class LoadedModel:
@ -1156,7 +1167,7 @@ def pin_memory(tensor):
if not tensor.is_contiguous():
return False
size = tensor.numel() * tensor.element_size()
size = tensor.nbytes
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
@ -1183,7 +1194,7 @@ def unpin_memory(tensor):
return False
ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
size = tensor.nbytes
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
@ -1504,6 +1515,16 @@ def supports_fp8_compute(device=None):
return True
def supports_nvfp4_compute(device=None):
if not is_nvidia():
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7):
@ -1542,6 +1563,10 @@ def soft_empty_cache(force=False):
def unload_all_models():
free_memory(1e30, get_torch_device())
def debug_memory_summary():
if is_amd() or is_nvidia():
return torch.cuda.memory.memory_summary()
return ""
#TODO: might be cleaner to put this somewhere else
import threading

View File

@ -718,6 +718,7 @@ class ModelPatcher:
continue
cast_weight = self.force_cast_weights
m.comfy_force_cast_weights = self.force_cast_weights
if lowvram_weight:
if hasattr(m, "comfy_cast_weights"):
m.weight_function = []
@ -790,11 +791,12 @@ class ModelPatcher:
for param in params:
self.pin_weight_to_device("{}.{}".format(n, param))
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
if full_load:
self.model.to(device_to)

View File

@ -79,7 +79,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if input is not None:
if dtype is None:
if isinstance(input, QuantizedTensor):
dtype = input._layout_params["orig_dtype"]
dtype = input.params.orig_dtype
else:
dtype = input.dtype
if bias_dtype is None:
@ -412,26 +412,34 @@ def fp8_linear(self, input):
return None
input_dtype = input.dtype
input_shape = input.shape
tensor_3d = input.ndim == 3
if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
if tensor_3d:
input = input.reshape(-1, input_shape[2])
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
if input.ndim != 2:
return None
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
input_fp8 = input.to(dtype).contiguous()
layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape))
quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input)
uncast_bias_weight(self, w, bias, offload_stream)
return o
# Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape))
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
return None
uncast_bias_weight(self, w, bias, offload_stream)
if tensor_3d:
o = o.reshape((input_shape[0], input_shape[1], w.shape[0]))
return o
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
@ -477,14 +485,20 @@ if CUBLAS_IS_AVAILABLE:
# ==============================================================================
# Mixed Precision Operations
# ==============================================================================
from .quant_ops import QuantizedTensor, QUANT_ALGOS
from .quant_ops import (
QuantizedTensor,
QUANT_ALGOS,
TensorCoreFP8Layout,
get_layout_class,
)
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
_quant_config = quant_config
_compute_dtype = compute_dtype
_full_precision_mm = full_precision_mm
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
@ -497,21 +511,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self._has_bias = bias
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False
def reset_parameters(self):
return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
@ -520,7 +546,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}")
logging.warning(f"Missing weight for layer {layer_name}")
return
manually_loaded_keys = [weight_key]
@ -529,49 +556,61 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight_scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(weight_scale_key, None)
if scale is not None:
scale = scale.to(device)
layout_params = {
'scale': scale,
'orig_dtype': MixedPrecisionOps._compute_dtype,
'block_size': qconfig.get("group_size", None),
}
# Load format-specific parameters
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
# FP8: single tensor scale
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
if scale is not None:
manually_loaded_keys.append(weight_scale_key)
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
@ -586,13 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if destination is not None:
sd = destination
else:
sd = {}
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def _forward(self, input, weight, bias):
@ -607,12 +662,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def forward(self, input, *args, **kwargs):
run_every_op()
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
input_shape = input.shape
reshaped_3d = False
if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)):
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
return self._forward(input, self.weight, self.bias)
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
# Fall back to non-quantized for non-2D tensors
if input_reshaped.ndim == 2:
reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input)
# Reshape output back to 3D if input was 3D
if reshaped_3d:
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output
def convert_weight(self, weight, inplace=False, **kwargs):
if isinstance(weight, QuantizedTensor):
@ -622,7 +698,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None:
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
# dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else:
weight = weight.to(self.weight.dtype)
if return_weight:
@ -649,10 +726,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations")
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
disabled = set()
if not nvfp4_compute:
disabled.add("nvfp4")
if not fp8_compute:
disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2")
return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled)
if (
fp8_compute and

View File

@ -1,580 +1,174 @@
import torch
import logging
from typing import Tuple, Dict
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
QuantizedTensor,
QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op,
register_layout_class,
get_layout_class,
)
_CK_AVAILABLE = True
if torch.version.cuda is None:
ck.registry.disable("cuda")
else:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,):
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.")
_CK_AVAILABLE = False
class QuantizedTensor:
pass
class _CKFp8Layout:
pass
class _CKNvfp4Layout:
pass
def register_layout_class(name, cls):
pass
def get_layout_class(name):
return None
import comfy.float
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self, *arg, **kwargs):
return self._qdata.is_contiguous(*arg, **kwargs)
def storage(self):
return self._qdata.storage()
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# FP8 Layouts with Comfy-Specific Extensions
# ==============================================================================
def _create_transformed_qtensor(qt, transform_fn):
new_data = transform_fn(qt._qdata)
new_params = _copy_layout_params(qt._layout_params)
return QuantizedTensor(new_data, qt._layout_type, new_params)
class _TensorCoreFP8LayoutBase(_CKFp8Layout):
FP8_DTYPE = None # Must be overridden in subclass
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
if target_layout is not None and target_layout != torch.strided:
logging.warning(
f"QuantizedTensor: layout change requested to {target_layout}, "
f"but not supported. Ignoring layout."
)
# Handle device transfer
current_device = qt._qdata.device
if target_device is not None:
# Normalize device for comparison
if isinstance(target_device, str):
target_device = torch.device(target_device)
if isinstance(current_device, str):
current_device = torch.device(current_device)
if target_device != current_device:
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
new_q_data = qt._qdata.to(device=target_device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
if target_dtype is not None:
new_params["orig_dtype"] = target_dtype
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
return new_qt
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
return qt
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
"""Detach operation - creates a detached copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.detach())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
"""Clone operation - creates a deep copy of the quantized tensor."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _create_transformed_qtensor(qt, lambda x: x.clone())
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
"""Device/dtype transfer operation - handles .to(device) calls."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
op_name="_to_copy"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
"""Handle .to(device) calls using the dtype_layout variant."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
return _handle_device_transfer(
qt,
target_device=kwargs.get('device', None),
target_dtype=kwargs.get('dtype', None),
target_layout=kwargs.get('layout', None),
op_name="to"
)
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
qt_dest = args[0]
src = args[1]
non_blocking = args[2] if len(args) > 2 else False
if isinstance(qt_dest, QuantizedTensor):
if isinstance(src, QuantizedTensor):
# Copy from another quantized tensor
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
qt_dest._layout_type = src._layout_type
orig_dtype = qt_dest._layout_params["orig_dtype"]
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
qt_dest._layout_params["orig_dtype"] = orig_dtype
else:
# Copy from regular tensor - just copy raw data
qt_dest._qdata.copy_(src)
return qt_dest
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
return True
@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
qt = args[0]
if isinstance(qt, QuantizedTensor):
# Create empty tensor with same shape and dtype as the quantized data
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
new_qdata = torch.empty_like(qt._qdata, **kwargs)
# Handle device transfer for layout params
target_device = kwargs.get('device', new_qdata.device)
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
# Update orig_dtype if dtype is specified
new_params['orig_dtype'] = hp_dtype
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
return func(*args, **kwargs)
# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================
class TensorCoreFP8Layout(QuantizedLayout):
"""
Storage format:
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
- scale: Scalar tensor (float32) for dequantization
- orig_dtype: Original dtype before quantization (for casting back)
"""
@classmethod
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if cls.FP8_DTYPE is None:
raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if isinstance(scale, str) and scale == "recalculate":
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
tensor_info = torch.finfo(tensor.dtype)
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
if scale is not None:
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
if scale is None:
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32)
if stochastic_rounding > 0:
if inplace_ops:
tensor *= (1.0 / scale).to(tensor.dtype)
else:
tensor = tensor * (1.0 / scale).to(tensor.dtype)
qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding)
else:
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE)
params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
else:
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
layout_params = {
'scale': scale,
'orig_dtype': orig_dtype
}
return tensor, layout_params
params = cls.Params(
scale=scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
block_scale=block_scale,
)
return qdata, params
@staticmethod
def dequantize(qdata, scale, orig_dtype, **kwargs):
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
plain_tensor.mul_(scale)
return plain_tensor
@classmethod
def get_plain_tensors(cls, qtensor):
return qtensor._qdata, qtensor._layout_params['scale']
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn
class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e5m2
# Backward compatibility alias - default to E4M3
TensorCoreFP8Layout = TensorCoreFP8E4M3Layout
# ==============================================================================
# Registry
# ==============================================================================
register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
QUANT_ALGOS = {
"float8_e4m3fn": {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8Layout",
"comfy_tensor_layout": "TensorCoreFP8E4M3Layout",
},
"float8_e5m2": {
"storage_t": torch.float8_e5m2,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreFP8E5M2Layout",
},
"nvfp4": {
"storage_t": torch.uint8,
"parameters": {"weight_scale", "weight_scale_2", "input_scale"},
"comfy_tensor_layout": "TensorCoreNVFP4Layout",
"group_size": 16,
},
}
LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
# ==============================================================================
# Re-exports for backward compatibility
# ==============================================================================
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
bias = args[2] if len(args) > 2 else None
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
out_dtype = kwargs.get("out_dtype")
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
weight_t = plain_weight.t()
tensor_2d = False
if len(plain_input.shape) == 2:
tensor_2d = True
plain_input = plain_input.unsqueeze(1)
input_shape = plain_input.shape
if len(input_shape) != 3:
return None
try:
output = torch._scaled_mm(
plain_input.reshape(-1, input_shape[2]).contiguous(),
weight_t,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
if not tensor_2d:
output = output.reshape((-1, input_shape[1], weight.shape[0]))
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
output_scale = scale_a * scale_b
output_params = {
'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype']
}
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else:
return output
except Exception as e:
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
# Case 2: DQ Fallback
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
if isinstance(input_tensor, QuantizedTensor):
input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias)
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
if out_dtype is None:
out_dtype = input_tensor._layout_params['orig_dtype']
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
output = torch._scaled_mm(
plain_input.contiguous(),
plain_weight,
bias=bias,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
)
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
output = output[0]
return output
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
input_tensor = args[1]
weight = args[2]
bias = args[0]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
if isinstance(args[2], QuantizedTensor):
a[2] = args[2].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
input_tensor = args[0]
weight = args[1]
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
a = list(args)
if isinstance(args[0], QuantizedTensor):
a[0] = args[0].dequantize()
if isinstance(args[1], QuantizedTensor):
a[1] = args[1].dequantize()
return func(*a, **kwargs)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)
__all__ = [
"QuantizedTensor",
"QuantizedLayout",
"TensorCoreFP8Layout",
"TensorCoreFP8E4M3Layout",
"TensorCoreFP8E5M2Layout",
"TensorCoreNVFP4Layout",
"QUANT_ALGOS",
"register_layout_op",
]

View File

@ -218,7 +218,7 @@ class CLIP:
if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
@ -266,7 +266,7 @@ class CLIP:
if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
@ -299,8 +299,11 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k]
return sd_clip
def load_model(self):
model_management.load_model_gpu(self.patcher)
def load_model(self, tokens={}):
memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
return self.patcher
def get_key_patches(self):
@ -476,8 +479,8 @@ class VAE:
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (8, 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
@ -1041,7 +1044,8 @@ class TEModel(Enum):
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
QWEN3_2B = 17
JINA_CLIP_2 = 18
GEMMA_3_12B = 18
JINA_CLIP_2 = 19
def detect_te_model(sd):
@ -1055,9 +1059,9 @@ def detect_te_model(sd):
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
if weight.shape[0] == 10240:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
elif weight.shape[0] == 5120:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD
@ -1067,6 +1071,8 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
@ -1271,6 +1277,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.NEWBIE:
clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer

View File

@ -836,6 +836,21 @@ class LTXV(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
class LTXAV(LTXV):
unet_config = {
"image_model": "ltxav",
}
latent_format = latent_formats.LTXAV
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 0.077 # TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device)
return out
class HunyuanVideo(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan_video",
@ -1536,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5):
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
models += [SVD_img2vid]

View File

@ -154,7 +154,8 @@ class TAEHV(nn.Module):
self._show_progress_bar = value
def encode(self, x, **kwargs):
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
if self.patch_size > 1:
x = F.pixel_unshuffle(x, self.patch_size)
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if x.shape[1] % 4 != 0:
# pad at end to multiple of 4
@ -167,5 +168,6 @@ class TAEHV(nn.Module):
def decode(self, x, **kwargs):
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
if self.patch_size > 1:
x = F.pixel_shuffle(x, self.patch_size)
return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return CosmosTEModel_

View File

@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return MochiTEModel_

View File

@ -7,6 +7,7 @@ import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit
import comfy.clip_model
from . import qwen_vl
@ -188,6 +189,31 @@ class Gemma3_4B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
@dataclass
class Gemma3_12B_Config:
vocab_size: int = 262208
hidden_size: int = 3840
intermediate_size: int = 15360
num_hidden_layers: int = 48
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 131072
rms_norm_eps: float = 1e-6
rope_theta = [1000000.0, 10000.0]
transformer_type: str = "gemma3"
head_dim = 256
rms_norm_add = True
mlp_activation = "gelu_pytorch_tanh"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
rope_scale = [8.0, 1.0]
final_norm: bool = True
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
mm_tokens_per_image = 256
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
super().__init__()
@ -520,6 +546,41 @@ class Llama2_(nn.Module):
return x, intermediate
class Gemma3MultiModalProjector(torch.nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
self.mm_input_projection_weight = nn.Parameter(
torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype)
)
self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"])
self.tokens_per_side = int(config.mm_tokens_per_image**0.5)
self.kernel_size = self.patches_per_image // self.tokens_per_side
self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size)
def forward(self, vision_outputs: torch.Tensor):
batch_size, _, seq_length = vision_outputs.shape
reshaped_vision_outputs = vision_outputs.transpose(1, 2)
reshaped_vision_outputs = reshaped_vision_outputs.reshape(
batch_size, seq_length, self.patches_per_image, self.patches_per_image
)
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
pooled_vision_outputs = pooled_vision_outputs.flatten(2)
pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2)
normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs)
projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype))
return projected_vision_outputs.type_as(vision_outputs)
class BaseLlama:
def get_input_embeddings(self):
return self.model.embed_tokens
@ -636,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Gemma3_12B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.dtype = dtype
self.image_size = config.vision_config["image_size"]
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None

View File

@ -1,7 +1,11 @@
from comfy import sd1_clip
import os
from transformers import T5TokenizerFast
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch
import comfy.utils
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -16,3 +20,123 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(text, return_word_ids)
embed_count = 0
for k in text_tokens:
tt = text_tokens[k]
for r in tt:
for i in range(len(r)):
if r[i][0] == 262144:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return text_tokens
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
double_precision_rope=True,
dtype=dtype,
device=device,
operations=operations,
)
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
self.gemma3_12b.set_clip_options(options)
def reset_clip_options(self):
self.gemma3_12b.reset_clip_options()
self.execution_device = None
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
return out.to(out_device), pooled
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
return self.load_state_dict(sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0
if comfy.model_management.should_use_bf16(device):
constant /= 2.0
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_

View File

@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

View File

@ -1198,7 +1198,7 @@ def unpack_latents(combined_latent, latent_shapes):
combined_latent = combined_latent[:, :, cut:]
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
else:
output_tensors = combined_latent
output_tensors = [combined_latent]
return output_tensors
def detect_layer_quantization(state_dict, prefix):
@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
out_sd = {}
layers = {}
for k in list(state_dict.keys()):
if k == scaled_fp8_key:
continue
if not k.startswith(model_prefix):
out_sd[k] = state_dict[k]
continue

View File

@ -1113,6 +1113,18 @@ class DynamicSlot(ComfyTypeI):
out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="IMAGECOMPARE")
class ImageCompare(ComfyTypeI):
Type = dict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True):
super().__init__(id, display_name, optional, tooltip, None, None, socketless)
def as_dict(self):
return super().as_dict()
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -1213,6 +1225,7 @@ class NodeInfoV1:
deprecated: bool=None
experimental: bool=None
api_node: bool=None
price_badge: dict | None = None
@dataclass
class NodeInfoV3:
@ -1222,11 +1235,77 @@ class NodeInfoV3:
name: str=None
display_name: str=None
description: str=None
python_module: Any = None
category: str=None
output_node: bool=None
deprecated: bool=None
experimental: bool=None
api_node: bool=None
price_badge: dict | None = None
@dataclass
class PriceBadgeDepends:
widgets: list[str] = field(default_factory=list)
inputs: list[str] = field(default_factory=list)
input_groups: list[str] = field(default_factory=list)
def validate(self) -> None:
if not isinstance(self.widgets, list) or any(not isinstance(x, str) for x in self.widgets):
raise ValueError("PriceBadgeDepends.widgets must be a list[str].")
if not isinstance(self.inputs, list) or any(not isinstance(x, str) for x in self.inputs):
raise ValueError("PriceBadgeDepends.inputs must be a list[str].")
if not isinstance(self.input_groups, list) or any(not isinstance(x, str) for x in self.input_groups):
raise ValueError("PriceBadgeDepends.input_groups must be a list[str].")
def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]:
# Build lookup: widget_id -> io_type
input_types: dict[str, str] = {}
for inp in schema_inputs:
all_inputs = inp.get_all()
input_types[inp.id] = inp.get_io_type() # First input is always the parent itself
for nested_inp in all_inputs[1:]:
# For DynamicCombo/DynamicSlot, nested inputs are prefixed with parent ID
# to match frontend naming convention (e.g., "should_texture.enable_pbr")
prefixed_id = f"{inp.id}.{nested_inp.id}"
input_types[prefixed_id] = nested_inp.get_io_type()
# Enrich widgets with type information, raising error for unknown widgets
widgets_data: list[dict[str, str]] = []
for w in self.widgets:
if w not in input_types:
raise ValueError(
f"PriceBadge depends_on.widgets references unknown widget '{w}'. "
f"Available widgets: {list(input_types.keys())}"
)
widgets_data.append({"name": w, "type": input_types[w]})
return {
"widgets": widgets_data,
"inputs": self.inputs,
"input_groups": self.input_groups,
}
@dataclass
class PriceBadge:
expr: str
depends_on: PriceBadgeDepends = field(default_factory=PriceBadgeDepends)
engine: str = field(default="jsonata")
def validate(self) -> None:
if self.engine != "jsonata":
raise ValueError(f"Unsupported PriceBadge.engine '{self.engine}'. Only 'jsonata' is supported.")
if not isinstance(self.expr, str) or not self.expr.strip():
raise ValueError("PriceBadge.expr must be a non-empty string.")
self.depends_on.validate()
def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]:
return {
"engine": self.engine,
"depends_on": self.depends_on.as_dict(schema_inputs),
"expr": self.expr,
}
@dataclass
@ -1272,6 +1351,8 @@ class Schema:
"""Flags a node as experimental, informing users that it may change or not work as expected."""
is_api_node: bool=False
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
price_badge: PriceBadge | None = None
"""Optional client-evaluated pricing badge declaration for this node."""
not_idempotent: bool=False
"""Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph."""
enable_expand: bool=False
@ -1302,6 +1383,8 @@ class Schema:
input.validate()
for output in self.outputs:
output.validate()
if self.price_badge is not None:
self.price_badge.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
@ -1375,7 +1458,8 @@ class Schema:
deprecated=self.is_deprecated,
experimental=self.is_experimental,
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
)
return info
@ -1407,7 +1491,8 @@ class Schema:
deprecated=self.is_deprecated,
experimental=self.is_experimental,
api_node=self.is_api_node,
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
)
return info
@ -1958,4 +2043,7 @@ __all__ = [
"add_to_dict_v1",
"add_to_dict_v3",
"V3Data",
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
]

View File

@ -0,0 +1,41 @@
from pydantic import BaseModel, Field
class SubjectReference(BaseModel):
id: str = Field(...)
images: list[str] = Field(...)
class TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(..., max_length=2000)
duration: int = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
aspect_ratio: str | None = Field(None)
resolution: str | None = Field(None)
movement_amplitude: str | None = Field(None)
images: list[str] | None = Field(None, description="Base64 encoded string or image URL")
subjects: list[SubjectReference] | None = Field(None)
bgm: bool | None = Field(None)
audio: bool | None = Field(None)
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: str = Field(...)
created_at: str = Field(...)
code: int | None = Field(None, description="Error code")
class TaskResult(BaseModel):
id: str = Field(..., description="Creation id")
url: str = Field(..., description="The URL of the generated results, valid for one hour")
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
class TaskStatusResponse(BaseModel):
state: str = Field(...)
err_code: str | None = Field(None)
progress: float | None = Field(None)
credits: int | None = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")

View File

@ -97,6 +97,9 @@ class FluxProUltraImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.06}""",
),
)
@classmethod
@ -352,6 +355,9 @@ class FluxProExpandNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.05}""",
),
)
@classmethod
@ -458,6 +464,9 @@ class FluxProFillNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.05}""",
),
)
@classmethod
@ -511,6 +520,21 @@ class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
DISPLAY_NAME = "Flux.2 [pro] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate"
PRICE_BADGE_EXPR = """
(
$MP := 1024 * 1024;
$outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]);
$outputCost := 0.03 + 0.015 * ($outMP - 1);
inputs.images.connected
? {
"type":"range_usd",
"min_usd": $outputCost + 0.015,
"max_usd": $outputCost + 0.12,
"format": { "approximate": true }
}
: {"type":"usd","usd": $outputCost}
)
"""
@classmethod
def define_schema(cls) -> IO.Schema:
@ -563,6 +587,10 @@ class Flux2ProImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]),
expr=cls.PRICE_BADGE_EXPR,
),
)
@classmethod
@ -623,6 +651,22 @@ class Flux2MaxImageNode(Flux2ProImageNode):
NODE_ID = "Flux2MaxImageNode"
DISPLAY_NAME = "Flux.2 [max] Image"
API_ENDPOINT = "/proxy/bfl/flux-2-max/generate"
PRICE_BADGE_EXPR = """
(
$MP := 1024 * 1024;
$outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]);
$outputCost := 0.07 + 0.03 * ($outMP - 1);
inputs.images.connected
? {
"type":"range_usd",
"min_usd": $outputCost + 0.03,
"max_usd": $outputCost + 0.24,
"format": { "approximate": true }
}
: {"type":"usd","usd": $outputCost}
)
"""
class BFLExtension(ComfyExtension):

View File

@ -126,6 +126,9 @@ class ByteDanceImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
)
@classmethod
@ -367,6 +370,19 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03;
{
"type":"usd",
"usd": $price,
"format": { "suffix":" x images/Run", "approximate": true }
}
)
""",
),
)
@classmethod
@ -522,6 +538,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -632,6 +649,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -754,6 +772,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -877,6 +896,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -946,6 +966,52 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
)
PRICE_BADGE_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$priceByModel := {
"seedance-1-0-pro": {
"480p":[0.23,0.24],
"720p":[0.51,0.56],
"1080p":[1.18,1.22]
},
"seedance-1-0-pro-fast": {
"480p":[0.09,0.1],
"720p":[0.21,0.23],
"1080p":[0.47,0.49]
},
"seedance-1-0-lite": {
"480p":[0.17,0.18],
"720p":[0.37,0.41],
"1080p":[0.85,0.88]
}
};
$model := widgets.model;
$modelKey :=
$contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" :
$contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" :
"seedance-1-0-lite";
$resolution := widgets.resolution;
$resKey :=
$contains($resolution, "1080") ? "1080p" :
$contains($resolution, "720") ? "720p" :
"480p";
$modelPrices := $lookup($priceByModel, $modelKey);
$baseRange := $lookup($modelPrices, $resKey);
$min10s := $baseRange[0];
$max10s := $baseRange[1];
$scale := widgets.duration / 10;
$minCost := $min10s * $scale;
$maxCost := $max10s * $scale;
($minCost = $maxCost)
? {"type":"usd","usd": $minCost}
: {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost}
)
""",
)
class ByteDanceExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:

View File

@ -130,7 +130,7 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
Returns:
List of response parts matching the requested type.
"""
if response.candidates is None:
if not response.candidates:
if response.promptFeedback and response.promptFeedback.blockReason:
feedback = response.promptFeedback
raise ValueError(
@ -141,14 +141,24 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
"try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
)
parts = []
for part in response.candidates[0].content.parts:
if part_type == "text" and part.text:
parts.append(part)
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
# Skip parts that don't match the requested type
blocked_reasons = []
for candidate in response.candidates:
if candidate.finishReason and candidate.finishReason.upper() == "IMAGE_PROHIBITED_CONTENT":
blocked_reasons.append(candidate.finishReason)
continue
if candidate.content is None or candidate.content.parts is None:
continue
for part in candidate.content.parts:
if part_type == "text" and part.text:
parts.append(part)
elif part.inlineData and part.inlineData.mimeType == part_type:
parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type:
parts.append(part)
if not parts and blocked_reasons:
raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}")
return parts
@ -309,6 +319,30 @@ class GeminiNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "gemini-2.5-flash") ? {
"type": "list_usd",
"usd": [0.0003, 0.0025],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"}
}
: $contains($m, "gemini-2.5-pro") ? {
"type": "list_usd",
"usd": [0.00125, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gemini-3-pro-preview") ? {
"type": "list_usd",
"usd": [0.002, 0.012],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"}
)
""",
),
)
@classmethod
@ -570,6 +604,9 @@ class GeminiImage(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""",
),
)
@classmethod
@ -700,6 +737,19 @@ class GeminiImage2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$r := widgets.resolution;
($contains($r,"1k") or $contains($r,"2k"))
? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}}
: $contains($r,"4k")
? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}}
: {"type":"text","text":"Token-based"}
)
""",
),
)
@classmethod

View File

@ -236,7 +236,6 @@ class IdeogramV1(IO.ComfyNode):
display_name="Ideogram V1",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
is_api_node=True,
inputs=[
IO.String.Input(
"prompt",
@ -298,6 +297,17 @@ class IdeogramV1(IO.ComfyNode):
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0286 : 0.0858;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
@ -351,7 +361,6 @@ class IdeogramV2(IO.ComfyNode):
display_name="Ideogram V2",
category="api node/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
is_api_node=True,
inputs=[
IO.String.Input(
"prompt",
@ -436,6 +445,17 @@ class IdeogramV2(IO.ComfyNode):
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]),
expr="""
(
$n := widgets.num_images;
$base := (widgets.turbo = true) ? 0.0715 : 0.1144;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod
@ -506,7 +526,6 @@ class IdeogramV3(IO.ComfyNode):
category="api node/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
is_api_node=True,
inputs=[
IO.String.Input(
"prompt",
@ -591,6 +610,23 @@ class IdeogramV3(IO.ComfyNode):
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]),
expr="""
(
$n := widgets.num_images;
$speed := widgets.rendering_speed;
$hasChar := inputs.character_image.connected;
$base :=
$contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) :
$contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) :
$contains($speed,"turbo") ? ($hasChar ? 0.143 : 0.0429) :
0.0858;
{"type":"usd","usd": $round($base * $n, 2)}
)
""",
),
)
@classmethod

View File

@ -567,7 +567,7 @@ async def execute_lipsync(
# Upload the audio file to Comfy API and get download URL
if audio:
audio_url = await upload_audio_to_comfyapi(
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg"
)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else:
@ -764,6 +764,33 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
expr="""
(
$m := widgets.mode;
$contains($m,"v2-5-turbo")
? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
: $contains($m,"v2-1-master")
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: $contains($m,"v2-master")
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: $contains($m,"v1-6")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($m,"v1")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -807,6 +834,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -817,6 +845,16 @@ class OmniProTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -826,6 +864,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt: str,
aspect_ratio: str,
duration: int,
resolution: str = "1080p",
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
@ -837,6 +876,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -872,6 +912,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True,
tooltip="Up to 6 additional reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -882,6 +923,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -893,6 +944,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
first_frame: Input.Image,
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
@ -936,6 +988,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
prompt=prompt,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -964,6 +1017,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images",
tooltip="Up to 7 reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -974,6 +1028,16 @@ class OmniProImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -984,6 +1048,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
reference_images: Input.Image,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
@ -1005,6 +1070,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1036,6 +1102,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.",
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -1046,6 +1113,16 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.126, "pro": 0.168};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
),
)
@classmethod
@ -1058,6 +1135,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
reference_video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
@ -1090,6 +1168,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list if image_list else None,
video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1119,6 +1198,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
tooltip="Up to 4 additional reference images.",
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -1129,6 +1209,16 @@ class OmniProEditVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
$rates := {"std": 0.126, "pro": 0.168};
{"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}}
)
""",
),
)
@classmethod
@ -1139,6 +1229,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
@ -1171,6 +1262,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
duration=None,
image_list=image_list if image_list else None,
video_list=video_list,
mode="pro" if resolution == "1080p" else "std",
),
)
return await finish_omni_video_task(cls, response)
@ -1213,6 +1305,9 @@ class OmniProImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.028}""",
),
)
@classmethod
@ -1298,6 +1393,9 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14}""",
),
)
@classmethod
@ -1360,6 +1458,33 @@ class KlingImage2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]),
expr="""
(
$mode := widgets.mode;
$model := widgets.model_name;
$dur := widgets.duration;
$contains($model,"v2-5-turbo")
? ($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
: ($contains($model,"v2-1-master") or $contains($model,"v2-master"))
? ($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5"))
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($model,"v1")
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -1433,6 +1558,9 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.49}""",
),
)
@classmethod
@ -1503,6 +1631,33 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
expr="""
(
$m := widgets.mode;
$contains($m,"v2-5-turbo")
? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35})
: $contains($m,"v2-1")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: $contains($m,"v2-master")
? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4})
: $contains($m,"v1-6")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($m,"v1")
? (
$contains($m,"pro")
? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -1568,6 +1723,9 @@ class KlingVideoExtendNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.28}""",
),
)
@classmethod
@ -1649,6 +1807,29 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]),
expr="""
(
$mode := widgets.mode;
$model := widgets.model_name;
$dur := widgets.duration;
($contains($model,"v1-6") or $contains($model,"v1-5"))
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28})
)
: $contains($model,"v1")
? (
$contains($mode,"pro")
? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49})
: ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14})
)
: {"type":"usd","usd":0.14}
)
""",
),
)
@classmethod
@ -1713,6 +1894,16 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]),
expr="""
(
($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom"))
? {"type":"usd","usd":0.49}
: {"type":"usd","usd":0.28}
)
""",
),
)
@classmethod
@ -1767,6 +1958,9 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""",
),
)
@classmethod
@ -1827,6 +2021,9 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""",
),
)
@classmethod
@ -1877,6 +2074,9 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.7}""",
),
)
@classmethod
@ -1976,6 +2176,19 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]),
expr="""
(
$m := widgets.model_name;
$base :=
$contains($m,"kling-v1-5")
? (inputs.image.connected ? 0.028 : 0.014)
: ($contains($m,"kling-v1") ? 0.0035 : 0.014);
{"type":"usd","usd": $base * widgets.n}
)
""",
),
)
@classmethod
@ -2059,6 +2272,10 @@ class TextToVideoWithAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]),
expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""",
),
)
@classmethod
@ -2123,6 +2340,10 @@ class ImageToVideoWithAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]),
expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""",
),
)
@classmethod
@ -2203,6 +2424,15 @@ class MotionControl(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
expr="""
(
$prices := {"std": 0.07, "pro": 0.112};
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
)
""",
),
)
@classmethod

View File

@ -28,6 +28,22 @@ class ExecuteTaskRequest(BaseModel):
image_uri: str | None = Field(None)
PRICE_BADGE = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$prices := {
"ltx-2 (pro)": {"1920x1080":0.06,"2560x1440":0.12,"3840x2160":0.24},
"ltx-2 (fast)": {"1920x1080":0.04,"2560x1440":0.08,"3840x2160":0.16}
};
$modelPrices := $lookup($prices, $lowercase(widgets.model));
$pps := $lookup($modelPrices, widgets.resolution);
{"type":"usd","usd": $pps * widgets.duration}
)
""",
)
class TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -69,6 +85,7 @@ class TextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE,
)
@classmethod
@ -145,6 +162,7 @@ class ImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE,
)
@classmethod

View File

@ -189,6 +189,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m,"photon-flash-1")
? {"type":"usd","usd":0.0027}
: $contains($m,"photon-1")
? {"type":"usd","usd":0.0104}
: {"type":"usd","usd":0.0246}
)
""",
),
)
@classmethod
@ -303,6 +316,19 @@ class LumaImageModifyNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m,"photon-flash-1")
? {"type":"usd","usd":0.0027}
: $contains($m,"photon-1")
? {"type":"usd","usd":0.0104}
: {"type":"usd","usd":0.0246}
)
""",
),
)
@classmethod
@ -395,6 +421,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -505,6 +532,8 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -568,6 +597,53 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
return LumaKeyframes(frame0=frame0, frame1=frame1)
PRICE_BADGE_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution", "duration"]),
expr="""
(
$p := {
"ray-flash-2": {
"5s": {"4k":3.13,"1080p":0.79,"720p":0.34,"540p":0.2},
"9s": {"4k":5.65,"1080p":1.42,"720p":0.61,"540p":0.36}
},
"ray-2": {
"5s": {"4k":9.11,"1080p":2.27,"720p":1.02,"540p":0.57},
"9s": {"4k":16.4,"1080p":4.1,"720p":1.83,"540p":1.03}
}
};
$m := widgets.model;
$d := widgets.duration;
$r := widgets.resolution;
$modelKey :=
$contains($m,"ray-flash-2") ? "ray-flash-2" :
$contains($m,"ray-2") ? "ray-2" :
$contains($m,"ray-1-6") ? "ray-1-6" :
"other";
$durKey := $contains($d,"5s") ? "5s" : $contains($d,"9s") ? "9s" : "";
$resKey :=
$contains($r,"4k") ? "4k" :
$contains($r,"1080p") ? "1080p" :
$contains($r,"720p") ? "720p" :
$contains($r,"540p") ? "540p" : "";
$modelPrices := $lookup($p, $modelKey);
$durPrices := $lookup($modelPrices, $durKey);
$v := $lookup($durPrices, $resKey);
$price :=
($modelKey = "ray-1-6") ? 0.5 :
($modelKey = "other") ? 0.79 :
($exists($v) ? $v : 0.79);
{"type":"usd","usd": $price}
)
""",
)
class LumaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:

View File

@ -134,6 +134,9 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.43}""",
),
)
@classmethod
@ -197,6 +200,9 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.43}""",
),
)
@classmethod
@ -340,6 +346,20 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
expr="""
(
$prices := {
"768p": {"6": 0.28, "10": 0.56},
"1080p": {"6": 0.49}
};
$resPrices := $lookup($prices, $lowercase(widgets.resolution));
$price := $lookup($resPrices, $string(widgets.duration));
{"type":"usd","usd": $price ? $price : 0.43}
)
""",
),
)
@classmethod

View File

@ -233,6 +233,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod
@ -351,6 +355,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 2.25}""",
),
)
@classmethod
@ -471,6 +479,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod

View File

@ -160,6 +160,23 @@ class OpenAIDalle2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]),
expr="""
(
$size := widgets.size;
$nRaw := widgets.n;
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
$base :=
$contains($size, "256x256") ? 0.016 :
$contains($size, "512x512") ? 0.018 :
0.02;
{"type":"usd","usd": $round($base * $n, 3)}
)
""",
),
)
@classmethod
@ -287,6 +304,25 @@ class OpenAIDalle3(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]),
expr="""
(
$size := widgets.size;
$q := widgets.quality;
$hd := $contains($q, "hd");
$price :=
$contains($size, "1024x1024")
? ($hd ? 0.08 : 0.04)
: (($contains($size, "1792x1024") or $contains($size, "1024x1792"))
? ($hd ? 0.12 : 0.08)
: 0.04);
{"type":"usd","usd": $price}
)
""",
),
)
@classmethod
@ -411,6 +447,28 @@ class OpenAIGPTImage1(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]),
expr="""
(
$ranges := {
"low": [0.011, 0.02],
"medium": [0.046, 0.07],
"high": [0.167, 0.3]
};
$range := $lookup($ranges, widgets.quality);
$n := widgets.n;
($n = 1)
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]}
: {
"type":"range_usd",
"min_usd": $range[0],
"max_usd": $range[1],
"format": { "suffix": " x " & $string($n) & "/Run" }
}
)
""",
),
)
@classmethod
@ -566,6 +624,75 @@ class OpenAIChatNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "o4-mini") ? {
"type": "list_usd",
"usd": [0.0011, 0.0044],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o1-pro") ? {
"type": "list_usd",
"usd": [0.15, 0.6],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o1") ? {
"type": "list_usd",
"usd": [0.015, 0.06],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o3-mini") ? {
"type": "list_usd",
"usd": [0.0011, 0.0044],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "o3") ? {
"type": "list_usd",
"usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4o") ? {
"type": "list_usd",
"usd": [0.0025, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd",
"usd": [0.0001, 0.0004],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1-mini") ? {
"type": "list_usd",
"usd": [0.0004, 0.0016],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1") ? {
"type": "list_usd",
"usd": [0.002, 0.008],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5-nano") ? {
"type": "list_usd",
"usd": [0.00005, 0.0004],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5-mini") ? {
"type": "list_usd",
"usd": [0.00025, 0.002],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5") ? {
"type": "list_usd",
"usd": [0.00125, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type": "text", "text": "Token-based"}
)
""",
),
)
@classmethod

View File

@ -128,6 +128,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -242,6 +243,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -355,6 +357,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=PRICE_BADGE_VIDEO,
)
@classmethod
@ -416,6 +419,33 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
PRICE_BADGE_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds", "quality", "motion_mode"]),
expr="""
(
$prices := {
"5": {
"1080p": {"normal": 1.2, "fast": 1.2},
"720p": {"normal": 0.6, "fast": 1.2},
"540p": {"normal": 0.45, "fast": 0.9},
"360p": {"normal": 0.45, "fast": 0.9}
},
"8": {
"1080p": {"normal": 1.2, "fast": 1.2},
"720p": {"normal": 1.2, "fast": 1.2},
"540p": {"normal": 0.9, "fast": 1.2},
"360p": {"normal": 0.9, "fast": 1.2}
}
};
$durPrices := $lookup($prices, $string(widgets.duration_seconds));
$qualityPrices := $lookup($durPrices, widgets.quality);
$price := $lookup($qualityPrices, widgets.motion_mode);
{"type":"usd","usd": $price ? $price : 0.9}
)
""",
)
class PixVerseExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:

View File

@ -378,6 +378,10 @@ class RecraftTextToImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
),
)
@classmethod
@ -490,6 +494,10 @@ class RecraftImageToImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
),
)
@classmethod
@ -591,6 +599,10 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""",
),
)
@classmethod
@ -692,6 +704,10 @@ class RecraftTextToVectorNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["n"]),
expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""",
),
)
@classmethod
@ -759,6 +775,10 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 0.01}""",
),
)
@classmethod
@ -817,6 +837,9 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.04}""",
),
)
@classmethod
@ -883,6 +906,9 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.01}""",
),
)
@classmethod
@ -929,6 +955,9 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.004}""",
),
)
@classmethod
@ -972,6 +1001,9 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)

View File

@ -241,6 +241,9 @@ class Rodin3D_Regular(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -294,6 +297,9 @@ class Rodin3D_Detail(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -347,6 +353,9 @@ class Rodin3D_Smooth(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -406,6 +415,9 @@ class Rodin3D_Sketch(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod

View File

@ -184,6 +184,10 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
),
)
@classmethod
@ -274,6 +278,10 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
),
)
@classmethod
@ -372,6 +380,10 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""",
),
)
@classmethod
@ -457,6 +469,9 @@ class RunwayTextToImageNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.11}""",
),
)
@classmethod

View File

@ -89,6 +89,24 @@ class OpenAIVideoSora2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]),
expr="""
(
$m := widgets.model;
$size := widgets.size;
$dur := widgets.duration;
$isPro := $contains($m, "sora-2-pro");
$isSora2 := $contains($m, "sora-2");
$isProSize := ($size = "1024x1792" or $size = "1792x1024");
$perSec :=
$isPro ? ($isProSize ? 0.5 : 0.3) :
$isSora2 ? 0.1 :
($isProSize ? 0.5 : 0.1);
{"type":"usd","usd": $round($perSec * $dur, 2)}
)
""",
),
)
@classmethod

View File

@ -127,6 +127,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.08}""",
),
)
@classmethod
@ -264,6 +267,16 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$contains(widgets.model,"large")
? {"type":"usd","usd":0.065}
: {"type":"usd","usd":0.035}
)
""",
),
)
@classmethod
@ -382,6 +395,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)
@classmethod
@ -486,6 +502,9 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)
@classmethod
@ -566,6 +585,9 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.01}""",
),
)
@classmethod
@ -648,6 +670,9 @@ class StabilityTextToAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
@ -732,6 +757,9 @@ class StabilityAudioToAudio(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod
@ -828,6 +856,9 @@ class StabilityAudioInpaint(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.2}""",
),
)
@classmethod

View File

@ -2,7 +2,6 @@ import builtins
from io import BytesIO
import aiohttp
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
@ -138,7 +137,7 @@ class TopazImageEnhance(IO.ComfyNode):
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str = "",
subject_detection: str = "All",
face_enhancement: bool = True,
@ -153,7 +152,9 @@ class TopazImageEnhance(IO.ComfyNode):
) -> IO.NodeOutput:
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
download_url = await upload_images_to_comfyapi(
cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096
)
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),

View File

@ -117,6 +117,38 @@ class TripoTextToModelNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
),
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 20 : ($withTexture ? 20 : 10);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
)
@classmethod
@ -155,7 +187,7 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
face_limit=face_limit if face_limit != -1 else None,
geometry_quality=geometry_quality,
auto_size=True,
quad=quad,
@ -210,6 +242,38 @@ class TripoImageToModelNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
),
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
)
@classmethod
@ -255,7 +319,7 @@ class TripoImageToModelNode(IO.ComfyNode):
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
face_limit=face_limit if face_limit != -1 else None,
auto_size=True,
quad=quad,
),
@ -314,6 +378,34 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
),
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
)
@classmethod
@ -369,7 +461,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
texture_quality=texture_quality,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
face_limit=face_limit if face_limit != -1 else None,
quad=quad,
),
)
@ -405,6 +497,15 @@ class TripoTextureNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]),
expr="""
(
$tq := widgets.texture_quality;
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)}
)
""",
),
)
@classmethod
@ -456,6 +557,9 @@ class TripoRefineNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.3}""",
),
)
@classmethod
@ -489,6 +593,9 @@ class TripoRigNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
),
)
@classmethod
@ -545,6 +652,9 @@ class TripoRetargetNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1}""",
),
)
@classmethod
@ -638,6 +748,60 @@ class TripoConversionNode(IO.ComfyNode):
],
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"quad",
"face_limit",
"texture_size",
"texture_format",
"force_symmetry",
"flatten_bottom",
"flatten_bottom_threshold",
"pivot_to_center_bottom",
"scale_factor",
"with_animation",
"pack_uv",
"bake",
"part_names",
"fbx_preset",
"export_vertex_colors",
"export_orientation",
"animate_in_place",
],
),
expr="""
(
$face := (widgets.face_limit != null) ? widgets.face_limit : -1;
$texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096;
$flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0;
$scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1;
$texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg");
$part := widgets.part_names;
$fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender");
$orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default");
$advanced :=
widgets.quad or
widgets.force_symmetry or
widgets.flatten_bottom or
widgets.pivot_to_center_bottom or
widgets.with_animation or
widgets.pack_uv or
widgets.bake or
widgets.export_vertex_colors or
widgets.animate_in_place or
($face != -1) or
($texSize != 4096) or
($flatThresh != 0) or
($scale != 1) or
($texFmt != "jpeg") or
($part != "") or
($fbx != "blender") or
($orient != "default");
{"type":"usd","usd": ($advanced ? 0.1 : 0.05)}
)
""",
),
)
@classmethod

View File

@ -122,6 +122,10 @@ class VeoVideoGenerationNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]),
expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""",
),
)
@classmethod
@ -347,6 +351,20 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]),
expr="""
(
$m := widgets.model;
$a := widgets.generate_audio;
($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate"))
? {"type":"usd","usd": ($a ? 1.2 : 0.8)}
: ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate"))
? {"type":"usd","usd": ($a ? 3.2 : 1.6)}
: {"type":"range_usd","min_usd":0.8,"max_usd":3.2}
)
""",
),
)
@ -420,6 +438,30 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]),
expr="""
(
$prices := {
"veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 },
"veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 }
};
$m := widgets.model;
$ga := (widgets.generate_audio = "true");
$seconds := widgets.duration;
$modelKey :=
$contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" :
$contains($m, "veo-3.1-generate") ? "veo-3.1-generate" :
"";
$audioKey := $ga ? "audio" : "no_audio";
$modelPrices := $lookup($prices, $modelKey);
$pps := $lookup($modelPrices, $audioKey);
($pps != null)
? {"type":"usd","usd": $pps * $seconds}
: {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2}
)
""",
),
)
@classmethod

View File

@ -1,12 +1,13 @@
import logging
from enum import Enum
from typing import Literal, Optional, TypeVar
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.vidu import (
SubjectReference,
TaskCreationRequest,
TaskCreationResponse,
TaskResult,
TaskStatusResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
@ -17,6 +18,7 @@ from comfy_api_nodes.util import (
validate_image_aspect_ratio,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
validate_string,
)
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@ -25,98 +27,33 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
R = TypeVar("R")
class VideoModelName(str, Enum):
vidu_q1 = "viduq1"
class AspectRatio(str, Enum):
r_16_9 = "16:9"
r_9_16 = "9:16"
r_1_1 = "1:1"
class Resolution(str, Enum):
r_1080p = "1080p"
class MovementAmplitude(str, Enum):
auto = "auto"
small = "small"
medium = "medium"
large = "large"
class TaskCreationRequest(BaseModel):
model: VideoModelName = VideoModelName.vidu_q1
prompt: Optional[str] = Field(None, max_length=1500)
duration: Optional[Literal[5]] = 5
seed: Optional[int] = Field(0, ge=0, le=2147483647)
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
resolution: Optional[Resolution] = Resolution.r_1080p
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: str = Field(...)
created_at: str = Field(...)
code: Optional[int] = Field(None, description="Error code")
class TaskResult(BaseModel):
id: str = Field(..., description="Creation id")
url: str = Field(..., description="The URL of the generated results, valid for one hour")
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
class TaskStatusResponse(BaseModel):
state: str = Field(...)
err_code: Optional[str] = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")
def get_video_url_from_response(response) -> Optional[str]:
if response.creations:
return response.creations[0].url
return None
def get_video_from_response(response) -> TaskResult:
if not response.creations:
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
logging.info(error_msg)
raise RuntimeError(error_msg)
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
return response.creations[0]
async def execute_task(
cls: type[IO.ComfyNode],
vidu_endpoint: str,
payload: TaskCreationRequest,
estimated_duration: int,
) -> R:
response = await sync_op(
) -> list[TaskResult]:
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
response_model=TaskCreationResponse,
data=payload,
)
if response.state == "failed":
error_msg = f"Vidu request failed. Code: {response.code}"
logging.error(error_msg)
raise RuntimeError(error_msg)
return await poll_op(
if task_creation_response.state == "failed":
raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}")
response = await poll_op(
cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state,
estimated_duration=estimated_duration,
progress_extractor=lambda r: r.progress,
max_poll_attempts=320,
)
if not response.creations:
raise RuntimeError(
f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
)
return response.creations
class ViduTextToVideoNode(IO.ComfyNode):
@ -127,14 +64,9 @@ class ViduTextToVideoNode(IO.ComfyNode):
node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation",
category="api node/video/Vidu",
description="Generate video from text prompt",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input(
"model",
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
IO.String.Input(
"prompt",
multiline=True,
@ -163,22 +95,19 @@ class ViduTextToVideoNode(IO.ComfyNode):
),
IO.Combo.Input(
"aspect_ratio",
options=AspectRatio,
default=AspectRatio.r_16_9,
options=["16:9", "9:16", "1:1"],
tooltip="The aspect ratio of the output video",
optional=True,
),
IO.Combo.Input(
"resolution",
options=Resolution,
default=Resolution.r_1080p,
options=["1080p"],
tooltip="Supported values may vary by model & duration",
optional=True,
),
IO.Combo.Input(
"movement_amplitude",
options=MovementAmplitude,
default=MovementAmplitude.auto,
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
@ -192,6 +121,9 @@ class ViduTextToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
@ -208,7 +140,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
if not prompt:
raise ValueError("The prompt field is required and cannot be empty.")
payload = TaskCreationRequest(
model_name=model,
model=model,
prompt=prompt,
duration=duration,
seed=seed,
@ -216,8 +148,8 @@ class ViduTextToVideoNode(IO.ComfyNode):
resolution=resolution,
movement_amplitude=movement_amplitude,
)
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduImageToVideoNode(IO.ComfyNode):
@ -230,12 +162,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
category="api node/video/Vidu",
description="Generate video from image and optional prompt",
inputs=[
IO.Combo.Input(
"model",
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
IO.Image.Input(
"image",
tooltip="An image to be used as the start frame of the generated video",
@ -270,15 +197,13 @@ class ViduImageToVideoNode(IO.ComfyNode):
),
IO.Combo.Input(
"resolution",
options=Resolution,
default=Resolution.r_1080p,
options=["1080p"],
tooltip="Supported values may vary by model & duration",
optional=True,
),
IO.Combo.Input(
"movement_amplitude",
options=MovementAmplitude,
default=MovementAmplitude.auto.value,
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
@ -292,13 +217,16 @@ class ViduImageToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
image: Input.Image,
prompt: str,
duration: int,
seed: int,
@ -309,7 +237,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest(
model_name=model,
model=model,
prompt=prompt,
duration=duration,
seed=seed,
@ -322,8 +250,8 @@ class ViduImageToVideoNode(IO.ComfyNode):
max_images=1,
mime_type="image/png",
)
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduReferenceVideoNode(IO.ComfyNode):
@ -334,14 +262,9 @@ class ViduReferenceVideoNode(IO.ComfyNode):
node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation",
category="api node/video/Vidu",
description="Generate video from multiple images and prompt",
description="Generate video from multiple images and a prompt",
inputs=[
IO.Combo.Input(
"model",
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
IO.Image.Input(
"images",
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
@ -374,22 +297,19 @@ class ViduReferenceVideoNode(IO.ComfyNode):
),
IO.Combo.Input(
"aspect_ratio",
options=AspectRatio,
default=AspectRatio.r_16_9,
options=["16:9", "9:16", "1:1"],
tooltip="The aspect ratio of the output video",
optional=True,
),
IO.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=["1080p"],
tooltip="Supported values may vary by model & duration",
optional=True,
),
IO.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
@ -403,13 +323,16 @@ class ViduReferenceVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
async def execute(
cls,
model: str,
images: torch.Tensor,
images: Input.Image,
prompt: str,
duration: int,
seed: int,
@ -426,7 +349,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest(
model_name=model,
model=model,
prompt=prompt,
duration=duration,
seed=seed,
@ -440,8 +363,8 @@ class ViduReferenceVideoNode(IO.ComfyNode):
max_images=7,
mime_type="image/png",
)
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduStartEndToVideoNode(IO.ComfyNode):
@ -454,12 +377,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
category="api node/video/Vidu",
description="Generate a video from start and end frames and a prompt",
inputs=[
IO.Combo.Input(
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
IO.Image.Input(
"first_frame",
tooltip="Start frame",
@ -497,15 +415,13 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
),
IO.Combo.Input(
"resolution",
options=[model.value for model in Resolution],
default=Resolution.r_1080p.value,
options=["1080p"],
tooltip="Supported values may vary by model & duration",
optional=True,
),
IO.Combo.Input(
"movement_amplitude",
options=[model.value for model in MovementAmplitude],
default=MovementAmplitude.auto.value,
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame",
optional=True,
),
@ -519,14 +435,17 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
)
@classmethod
async def execute(
cls,
model: str,
first_frame: torch.Tensor,
end_frame: torch.Tensor,
first_frame: Input.Image,
end_frame: Input.Image,
prompt: str,
duration: int,
seed: int,
@ -535,7 +454,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model_name=model,
model=model,
prompt=prompt,
duration=duration,
seed=seed,
@ -546,8 +465,479 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame)
]
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2TextToVideoNode",
display_name="Vidu2 Text-to-Video Generation",
category="api node/video/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation, with a maximum length of 2000 characters.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=10,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Boolean.Input(
"background_music",
default=False,
tooltip="Whether to add background music to the generated video.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$is1080 := widgets.resolution = "1080p";
$base := $is1080 ? 0.1 : 0.075;
$perSec := $is1080 ? 0.05 : 0.025;
{"type":"usd","usd": $base + $perSec * (widgets.duration - 1)}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
background_music: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2000)
results = await execute_task(
cls,
VIDU_TEXT_TO_VIDEO,
TaskCreationRequest(
model=model,
prompt=prompt,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
bgm=background_music,
),
)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2ImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2ImageToVideoNode",
display_name="Vidu2 Image-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
IO.Image.Input(
"image",
tooltip="An image to be used as the start frame of the generated video.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="An optional text prompt for video generation (max 2000 characters).",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=10,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$m := widgets.model;
$d := widgets.duration;
$is1080 := widgets.resolution = "1080p";
$contains($m, "pro-fast")
? (
$base := $is1080 ? 0.08 : 0.04;
$perSec := $is1080 ? 0.02 : 0.01;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "pro")
? (
$base := $is1080 ? 0.275 : 0.075;
$perSec := $is1080 ? 0.075 : 0.05;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "turbo")
? (
$is1080
? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)}
: (
$d <= 1 ? {"type":"usd","usd": 0.04}
: $d <= 2 ? {"type":"usd","usd": 0.05}
: {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)}
)
)
: {"type":"usd","usd": 0.04}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_string(prompt, max_length=2000)
results = await execute_task(
cls,
VIDU_IMAGE_TO_VIDEO,
TaskCreationRequest(
model=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
images=await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
),
),
)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2ReferenceVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2ReferenceVideoNode",
display_name="Vidu2 Reference-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from multiple reference images and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
IO.Autogrow.Input(
"subjects",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_images"),
names=["subject1", "subject2", "subject3"],
min=1,
),
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
"Reference them in prompts via @subject{subject_id}.",
),
IO.String.Input(
"prompt",
multiline=True,
tooltip="When enabled, the video will include generated speech and background music "
"based on the prompt.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled video will contain generated speech and background music based on the prompt.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=10,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
IO.Combo.Input("resolution", options=["720p"]),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]),
expr="""
(
$is1080 := widgets.resolution = "1080p";
$base := $is1080 ? 0.375 : 0.125;
$perSec := $is1080 ? 0.05 : 0.025;
$audioCost := widgets.audio = true ? 0.075 : 0;
{"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
subjects: IO.Autogrow.Type,
prompt: str,
audio: bool,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2000)
total_images = 0
for i in subjects:
if get_number_of_images(subjects[i]) > 3:
raise ValueError("Maximum number of images per subject is 3.")
for im in subjects[i]:
total_images += 1
validate_image_aspect_ratio(im, (1, 4), (4, 1))
validate_image_dimensions(im, min_width=128, min_height=128)
if total_images > 7:
raise ValueError("Too many reference images; the maximum allowed is 7.")
subjects_param: list[SubjectReference] = []
for i in subjects:
subjects_param.append(
SubjectReference(
id=i,
images=await upload_images_to_comfyapi(
cls,
subjects[i],
max_images=3,
mime_type="image/png",
wait_label=f"Uploading reference images for {i}",
),
),
)
payload = TaskCreationRequest(
model=model,
prompt=prompt,
audio=audio,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
movement_amplitude=movement_amplitude,
subjects=subjects_param,
)
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2StartEndToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2StartEndToVideoNode",
display_name="Vidu2 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
IO.Image.Input("first_frame"),
IO.Image.Input("end_frame"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt description (max 2000 characters).",
),
IO.Int.Input(
"duration",
default=5,
min=2,
max=8,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]),
expr="""
(
$m := widgets.model;
$d := widgets.duration;
$is1080 := widgets.resolution = "1080p";
$contains($m, "pro-fast")
? (
$base := $is1080 ? 0.08 : 0.04;
$perSec := $is1080 ? 0.02 : 0.01;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "pro")
? (
$base := $is1080 ? 0.275 : 0.075;
$perSec := $is1080 ? 0.075 : 0.05;
{"type":"usd","usd": $base + $perSec * ($d - 1)}
)
: $contains($m, "turbo")
? (
$is1080
? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)}
: (
$d <= 2 ? {"type":"usd","usd": 0.05}
: {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)}
)
)
: {"type":"usd","usd": 0.04}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
first_frame: Input.Image,
end_frame: Input.Image,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2000)
if get_number_of_images(first_frame) > 1:
raise ValueError("Only one input image is allowed for `first_frame`.")
if get_number_of_images(end_frame) > 1:
raise ValueError("Only one input image is allowed for `end_frame`.")
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
images=[
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame)
],
)
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduExtension(ComfyExtension):
@ -558,6 +948,10 @@ class ViduExtension(ComfyExtension):
ViduImageToVideoNode,
ViduReferenceVideoNode,
ViduStartEndToVideoNode,
Vidu2TextToVideoNode,
Vidu2ImageToVideoNode,
Vidu2ReferenceVideoNode,
Vidu2StartEndToVideoNode,
]

View File

@ -13,7 +13,9 @@ from comfy_api_nodes.util import (
poll_op,
sync_op,
tensor_to_base64_string,
upload_video_to_comfyapi,
validate_audio_duration,
validate_video_duration,
)
@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel):
audio_url: str | None = Field(None)
class Reference2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
reference_video_urls: list[str] = Field(...)
class Txt2ImageParametersField(BaseModel):
size: str = Field(...)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel):
shot_type: str = Field("single")
class Reference2VideoParametersField(BaseModel):
size: str = Field(...)
duration: int = Field(5, ge=5, le=15)
shot_type: str = Field("single")
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(False)
class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2ImageInputField = Field(...)
@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel):
parameters: Image2VideoParametersField = Field(...)
class Reference2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Reference2VideoInputField = Field(...)
parameters: Reference2VideoParametersField = Field(...)
class TaskCreationOutputField(BaseModel):
task_id: str = Field(...)
task_status: str = Field(...)
@ -222,6 +244,9 @@ class WanTextToImageApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
)
@classmethod
@ -341,6 +366,9 @@ class WanImageToImageApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""",
),
)
@classmethod
@ -498,6 +526,17 @@ class WanTextToVideoApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]),
expr="""
(
$ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 };
$resKey := $substringBefore(widgets.size, ":");
$pps := $lookup($ppsTable, $resKey);
{ "type": "usd", "usd": $round($pps * widgets.duration, 2) }
)
""",
),
)
@classmethod
@ -659,6 +698,16 @@ class WanImageToVideoApi(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
expr="""
(
$ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 };
$pps := $lookup($ppsTable, widgets.resolution);
{ "type": "usd", "usd": $round($pps * widgets.duration, 2) }
)
""",
),
)
@classmethod
@ -721,6 +770,159 @@ class WanImageToVideoApi(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanReferenceVideoApi(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="WanReferenceVideoApi",
display_name="Wan Reference to Video",
category="api node/video/Wan",
description="Use the character and voice from input videos, combined with a prompt, "
"to generate a new video that maintains character consistency.",
inputs=[
IO.Combo.Input("model", options=["wan2.6-r2v"]),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. Supports English and Chinese. "
"Use identifiers such as `character1` and `character2` to refer to the reference characters.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative prompt describing what to avoid.",
),
IO.Autogrow.Input(
"reference_videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("reference_video"),
names=["character1", "character2", "character3"],
min=1,
),
),
IO.Combo.Input(
"size",
options=[
"720p: 1:1 (960x960)",
"720p: 16:9 (1280x720)",
"720p: 9:16 (720x1280)",
"720p: 4:3 (1088x832)",
"720p: 3:4 (832x1088)",
"1080p: 1:1 (1440x1440)",
"1080p: 16:9 (1920x1080)",
"1080p: 9:16 (1080x1920)",
"1080p: 4:3 (1632x1248)",
"1080p: 3:4 (1248x1632)",
],
),
IO.Int.Input(
"duration",
default=5,
min=5,
max=10,
step=5,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input(
"shot_type",
options=["single", "multi"],
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
"single continuous shot or multiple shots with cuts.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]),
expr="""
(
$rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10;
$inputMin := 2 * $rate;
$inputMax := 5 * $rate;
$outputPrice := widgets.duration * $rate;
{
"type": "range_usd",
"min_usd": $inputMin + $outputPrice,
"max_usd": $inputMax + $outputPrice
}
)
""",
),
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
negative_prompt: str,
reference_videos: IO.Autogrow.Type,
size: str,
duration: int,
seed: int,
shot_type: str,
watermark: bool,
):
reference_video_urls = []
for i in reference_videos:
validate_video_duration(reference_videos[i], min_duration=2, max_duration=30)
for i in reference_videos:
reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i]))
width, height = RES_IN_PARENS.search(size).groups()
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
response_model=TaskCreationResponse,
data=Reference2VideoTaskCreationRequest(
model=model,
input=Reference2VideoInputField(
prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls
),
parameters=Reference2VideoParametersField(
size=f"{width}*{height}",
duration=duration,
shot_type=shot_type,
watermark=watermark,
seed=seed,
),
),
)
if not initial_response.output:
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=6,
max_poll_attempts=280,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -729,6 +931,7 @@ class WanApiExtension(ComfyExtension):
WanImageToImageApi,
WanTextToVideoApi,
WanImageToVideoApi,
WanReferenceVideoApi,
]

View File

@ -55,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio(
image: torch.Tensor,
name: str | None = None,
*,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
@ -75,7 +75,7 @@ def tensor_to_bytesio(
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
return img_binary

View File

@ -49,6 +49,7 @@ async def upload_images_to_comfyapi(
mime_type: str | None = None,
wait_label: str | None = "Uploading",
show_batch_index: bool = True,
total_pixels: int = 2048 * 2048,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
@ -63,7 +64,7 @@ async def upload_images_to_comfyapi(
for idx in range(num_to_upload):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1:
@ -81,7 +82,6 @@ async def upload_audio_to_comfyapi(
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
@ -91,7 +91,7 @@ async def upload_audio_to_comfyapi(
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
async def upload_video_to_comfyapi(
@ -119,7 +119,7 @@ async def upload_video_to_comfyapi(
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
filename = f"{uuid.uuid4()}.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = BytesIO()

View File

@ -207,15 +207,15 @@ class ExecutionList(TopologicalSort):
return self.output_cache.get(node_id) is not None
def cache_link(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
if not from_node_id in self.execution_cache_listeners:
if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache:
if to_node_id not in self.execution_cache:
return None
value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:

View File

@ -14,8 +14,9 @@ class JobStatus:
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
FAILED = 'failed'
CANCELLED = 'cancelled'
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED]
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Media types that can be previewed in the frontend
@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
status_info = history_item.get('status', {})
status_str = status_info.get('status_str') if status_info else None
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.FAILED
else:
status = JobStatus.COMPLETED
outputs = history_item.get('outputs', {})
outputs_count, preview_output = get_outputs_summary(outputs)
@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_error = None
execution_start_time = None
execution_end_time = None
was_interrupted = False
if status_info:
messages = status_info.get('messages', [])
for entry in messages:
@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_end_time = event_data.get('timestamp')
if event_name == 'execution_error':
execution_error = event_data
elif event_name == 'execution_interrupted':
was_interrupted = True
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED
else:
status = JobStatus.COMPLETED
job = prune_dict({
'id': prompt_id,
@ -268,13 +273,13 @@ def get_all_jobs(
for item in queued:
jobs.append(normalize_queue_item(item, JobStatus.PENDING))
include_completed = JobStatus.COMPLETED in status_filter
include_failed = JobStatus.FAILED in status_filter
if include_completed or include_failed:
history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
requested_history_statuses = history_statuses & set(status_filter)
if requested_history_statuses:
for prompt_id, history_item in history.items():
is_failed = history_item.get('status', {}).get('status_str') == 'error'
if (is_failed and include_failed) or (not is_failed and include_completed):
jobs.append(normalize_history_item(prompt_id, history_item))
job = normalize_history_item(prompt_id, history_item)
if job.get('status') in requested_history_statuses:
jobs.append(job)
if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]

View File

@ -55,7 +55,8 @@ class APG(io.ComfyNode):
def pre_cfg_function(args):
nonlocal running_avg, prev_sigma
if len(args["conds_out"]) == 1: return args["conds_out"]
if len(args["conds_out"]) == 1:
return args["conds_out"]
cond = args["conds_out"][0]
uncond = args["conds_out"][1]

View File

@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode):
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]})
decode = execute # TODO: remove
@ -399,6 +399,58 @@ class SplitAudioChannels(IO.ComfyNode):
separate = execute # TODO: remove
class JoinAudioChannels(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="JoinAudioChannels",
display_name="Join Audio Channels",
description="Joins left and right mono audio channels into a stereo audio.",
category="audio",
inputs=[
IO.Audio.Input("audio_left"),
IO.Audio.Input("audio_right"),
],
outputs=[
IO.Audio.Output(display_name="audio"),
],
)
@classmethod
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
waveform_left = audio_left["waveform"]
sample_rate_left = audio_left["sample_rate"]
waveform_right = audio_right["waveform"]
sample_rate_right = audio_right["sample_rate"]
if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
raise ValueError("AudioJoin: Both input audios must be mono.")
# Handle different sample rates by resampling to the higher rate
waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
waveform_left, sample_rate_left, waveform_right, sample_rate_right
)
# Handle different lengths by trimming to the shorter length
length_left = waveform_left.shape[-1]
length_right = waveform_right.shape[-1]
if length_left != length_right:
min_length = min(length_left, length_right)
if length_left > min_length:
logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
waveform_left = waveform_left[..., :min_length]
if length_right > min_length:
logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
waveform_right = waveform_right[..., :min_length]
# Join the channels into stereo
left_channel = waveform_left[..., 0:1, :]
right_channel = waveform_right[..., 0:1, :]
stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
if sample_rate_1 != sample_rate_2:
@ -616,6 +668,7 @@ class AudioExtension(ComfyExtension):
RecordAudio,
TrimAudioDuration,
SplitAudioChannels,
JoinAudioChannels,
AudioConcat,
AudioMerge,
AudioAdjustVolume,

View File

@ -5,7 +5,9 @@ import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler
import folder_paths
import json
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
@classmethod
@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True)
if "blocks.0.block.0.conv.weight" in sd:
config = {
@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode):
"global_residual": False,
}
model_type = "720p"
model = HunyuanVideo15SRModel(model_type, config)
model.load_sd(sd)
elif "up.0.block.0.conv1.conv.weight" in sd:
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
config = {
@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode):
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
}
model_type = "1080p"
model = HunyuanVideo15SRModel(model_type, config)
model.load_sd(sd)
model = HunyuanVideo15SRModel(model_type, config)
model.load_sd(sd)
elif "post_upsample_res_blocks.0.conv2.bias" in sd:
config = json.loads(metadata["config"])
model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32]))
model.load_state_dict(sd)
return io.NodeOutput(model)

View File

@ -0,0 +1,53 @@
import nodes
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
class ImageCompare(IO.ComfyNode):
"""Compares two images with a slider interface."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageCompare",
display_name="Image Compare",
description="Compares two images side by side with a slider.",
category="image",
is_experimental=True,
is_output_node=True,
inputs=[
IO.Image.Input("image_a", optional=True),
IO.Image.Input("image_b", optional=True),
IO.ImageCompare.Input("compare_view"),
],
outputs=[],
)
@classmethod
def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput:
result = {"a_images": [], "b_images": []}
preview_node = nodes.PreviewImage()
if image_a is not None and len(image_a) > 0:
saved = preview_node.save_images(image_a, "comfy.compare.a")
result["a_images"] = saved["ui"]["images"]
if image_b is not None and len(image_b) > 0:
saved = preview_node.save_images(image_b, "comfy.compare.b")
result["b_images"] = saved["ui"]["images"]
return IO.NodeOutput(ui=result)
class ImageCompareExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCompare,
]
async def comfy_entrypoint() -> ImageCompareExtension:
return ImageCompareExtension()

View File

@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode):
generate = execute # TODO: remove
class LTXVImgToVideoInplace(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVImgToVideoInplace",
category="conditioning/video_models",
inputs=[
io.Vae.Input("vae"),
io.Image.Input("image"),
io.Latent.Input("latent"),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
if bypass:
return (latent,)
samples = latent["samples"]
_, height_scale_factor, width_scale_factor = (
vae.downscale_index_formula
)
batch, _, latent_frames, latent_height, latent_width = samples.shape
width = latent_width * width_scale_factor
height = latent_height * height_scale_factor
if image.shape[1] != height or image.shape[2] != width:
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
else:
pixels = image
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
samples[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = torch.ones(
(batch, 1, latent_frames, 1, 1),
dtype=torch.float32,
device=samples.device,
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
generate = execute # TODO: remove
def conditioning_get_any_value(conditioning, key, default=None):
for t in conditioning:
if key in t[1]:
@ -106,12 +159,12 @@ def get_keyframe_idxs(cond):
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
if keyframe_idxs is None:
return None, 0
num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
return keyframe_idxs, num_keyframes
class LTXVAddGuide(io.ComfyNode):
NUM_PREFIX_FRAMES = 2
PATCHIFIER = SymmetricPatchifier(1)
PATCHIFIER = SymmetricPatchifier(1, start_end=True)
@classmethod
def define_schema(cls):
@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode):
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
@classmethod
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors):
_, latent_idx = cls.get_latent_index(
cond=positive,
latent_length=latent_image.shape[2],
guide_length=guiding_latent.shape[2],
frame_idx=frame_idx,
scale_factors=scale_factors,
)
noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128):
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
raise ValueError("Adding guide to a combined AV latent is not supported.")
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors)
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
if guide_mask is not None:
target_h = max(noise_mask.shape[3], guide_mask.shape[3])
target_w = max(noise_mask.shape[4], guide_mask.shape[4])
if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1:
noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w)
if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1:
guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w)
mask = guide_mask - strength
else:
mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength,
dtype=noise_mask.dtype,
device=noise_mask.device,
)
# This solves audio video combined latent case where latent_image has audio latent concatenated
# in channel dimension with video latent. The solution is to pad guiding latent accordingly.
if latent_image.shape[1] > guiding_latent.shape[1]:
pad_len = latent_image.shape[1] - guiding_latent.shape[1]
guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0)
latent_image = torch.cat([latent_image, guiding_latent], dim=2)
noise_mask = torch.cat([noise_mask, mask], dim=2)
return positive, negative, latent_image, noise_mask
@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode):
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2])
positive, negative, latent_image, noise_mask = cls.append_keyframe(
positive,
negative,
frame_idx,
latent_image,
noise_mask,
t[:, :, :num_prefix_frames],
t,
strength,
scale_factors,
)
latent_idx += num_prefix_frames
t = t[:, :, num_prefix_frames:]
if t.shape[2] == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
latent_image, noise_mask = cls.replace_latent_frames(
latent_image,
noise_mask,
t,
latent_idx,
strength,
)
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})
generate = execute # TODO: remove
@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode):
preprocess = execute # TODO: remove
import comfy.nested_tensor
class LTXVConcatAVLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVConcatAVLatent",
category="latent/video/ltxv",
inputs=[
io.Latent.Input("video_latent"),
io.Latent.Input("audio_latent"),
],
outputs=[
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, video_latent, audio_latent) -> io.NodeOutput:
output = {}
output.update(video_latent)
output.update(audio_latent)
video_noise_mask = video_latent.get("noise_mask", None)
audio_noise_mask = audio_latent.get("noise_mask", None)
if video_noise_mask is not None or audio_noise_mask is not None:
if video_noise_mask is None:
video_noise_mask = torch.ones_like(video_latent["samples"])
if audio_noise_mask is None:
audio_noise_mask = torch.ones_like(audio_latent["samples"])
output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask))
output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"]))
return io.NodeOutput(output)
class LTXVSeparateAVLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LTXVSeparateAVLatent",
category="latent/video/ltxv",
description="LTXV Separate AV Latent",
inputs=[
io.Latent.Input("av_latent"),
],
outputs=[
io.Latent.Output(display_name="video_latent"),
io.Latent.Output(display_name="audio_latent"),
],
)
@classmethod
def execute(cls, av_latent) -> io.NodeOutput:
latents = av_latent["samples"].unbind()
video_latent = av_latent.copy()
video_latent["samples"] = latents[0]
audio_latent = av_latent.copy()
audio_latent["samples"] = latents[1]
if "noise_mask" in av_latent:
masks = av_latent["noise_mask"]
if masks is not None:
masks = masks.unbind()
video_latent["noise_mask"] = masks[0]
audio_latent["noise_mask"] = masks[1]
return io.NodeOutput(video_latent, audio_latent)
class LtxvExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyLTXVLatentVideo,
LTXVImgToVideo,
LTXVImgToVideoInplace,
ModelSamplingLTXV,
LTXVConditioning,
LTXVScheduler,
LTXVAddGuide,
LTXVPreprocess,
LTXVCropGuides,
LTXVConcatAVLatent,
LTXVSeparateAVLatent,
]

View File

@ -0,0 +1,224 @@
import folder_paths
import comfy.utils
import comfy.model_management
import torch
from comfy.ldm.lightricks.vae.audio_vae import AudioVAE
from comfy_api.latest import ComfyExtension, io
class LTXVAudioVAELoader(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAELoader",
display_name="LTXV Audio VAE Loader",
category="audio",
inputs=[
io.Combo.Input(
"ckpt_name",
options=folder_paths.get_filename_list("checkpoints"),
tooltip="Audio VAE checkpoint to load.",
)
],
outputs=[io.Vae.Output(display_name="Audio VAE")],
)
@classmethod
def execute(cls, ckpt_name: str) -> io.NodeOutput:
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
return io.NodeOutput(AudioVAE(sd, metadata))
class LTXVAudioVAEEncode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAEEncode",
display_name="LTXV Audio VAE Encode",
category="audio",
inputs=[
io.Audio.Input("audio", tooltip="The audio to be encoded."),
io.Vae.Input(
id="audio_vae",
display_name="Audio VAE",
tooltip="The Audio VAE model to use for encoding.",
),
],
outputs=[io.Latent.Output(display_name="Audio Latent")],
)
@classmethod
def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput:
audio_latents = audio_vae.encode(audio)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": int(audio_vae.sample_rate),
"type": "audio",
}
)
class LTXVAudioVAEDecode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVAudioVAEDecode",
display_name="LTXV Audio VAE Decode",
category="audio",
inputs=[
io.Latent.Input("samples", tooltip="The latent to be decoded."),
io.Vae.Input(
id="audio_vae",
display_name="Audio VAE",
tooltip="The Audio VAE model used for decoding the latent.",
),
],
outputs=[io.Audio.Output(display_name="Audio")],
)
@classmethod
def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput:
audio_latent = samples["samples"]
if audio_latent.is_nested:
audio_latent = audio_latent.unbind()[-1]
audio = audio_vae.decode(audio_latent).to(audio_latent.device)
output_audio_sample_rate = audio_vae.output_sample_rate
return io.NodeOutput(
{
"waveform": audio,
"sample_rate": int(output_audio_sample_rate),
}
)
class LTXVEmptyLatentAudio(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXVEmptyLatentAudio",
display_name="LTXV Empty Latent Audio",
category="latent/audio",
inputs=[
io.Int.Input(
"frames_number",
default=97,
min=1,
max=1000,
step=1,
display_mode=io.NumberDisplay.number,
tooltip="Number of frames.",
),
io.Int.Input(
"frame_rate",
default=25,
min=1,
max=1000,
step=1,
display_mode=io.NumberDisplay.number,
tooltip="Number of frames per second.",
),
io.Int.Input(
"batch_size",
default=1,
min=1,
max=4096,
display_mode=io.NumberDisplay.number,
tooltip="The number of latent audio samples in the batch.",
),
io.Vae.Input(
id="audio_vae",
display_name="Audio VAE",
tooltip="The Audio VAE model to get configuration from.",
),
],
outputs=[io.Latent.Output(display_name="Latent")],
)
@classmethod
def execute(
cls,
frames_number: int,
frame_rate: int,
batch_size: int,
audio_vae: AudioVAE,
) -> io.NodeOutput:
"""Generate empty audio latents matching the reference pipeline structure."""
assert audio_vae is not None, "Audio VAE model is required"
z_channels = audio_vae.latent_channels
audio_freq = audio_vae.latent_frequency_bins
sampling_rate = int(audio_vae.sample_rate)
num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate)
audio_latents = torch.zeros(
(batch_size, z_channels, num_audio_latents, audio_freq),
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput(
{
"samples": audio_latents,
"sample_rate": sampling_rate,
"type": "audio",
}
)
class LTXAVTextEncoderLoader(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="LTXAVTextEncoderLoader",
display_name="LTXV Audio Text Encoder Loader",
category="advanced/loaders",
description="[Recipes]\n\nltxav: gemma 3 12B",
inputs=[
io.Combo.Input(
"text_encoder",
options=folder_paths.get_filename_list("text_encoders"),
),
io.Combo.Input(
"ckpt_name",
options=folder_paths.get_filename_list("checkpoints"),
),
io.Combo.Input(
"device",
options=["default", "cpu"],
)
],
outputs=[io.Clip.Output()],
)
@classmethod
def execute(cls, text_encoder, ckpt_name, device="default"):
clip_type = comfy.sd.CLIPType.LTXV
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return io.NodeOutput(clip)
class LTXVAudioExtension(ComfyExtension):
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LTXVAudioVAELoader,
LTXVAudioVAEEncode,
LTXVAudioVAEDecode,
LTXVEmptyLatentAudio,
LTXAVTextEncoderLoader,
]
async def comfy_entrypoint() -> ComfyExtension:
return LTXVAudioExtension()

View File

@ -0,0 +1,75 @@
from comfy import model_management
import math
class LTXVLatentUpsampler:
"""
Upsamples a video latent by a factor of 2.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"upscale_model": ("LATENT_UPSCALE_MODEL",),
"vae": ("VAE",),
}
}
RETURN_TYPES = ("LATENT",)
FUNCTION = "upsample_latent"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def upsample_latent(
self,
samples: dict,
upscale_model,
vae,
) -> tuple:
"""
Upsample the input latent using the provided model.
Args:
samples (dict): Input latent samples
upscale_model (LatentUpsampler): Loaded upscale model
vae: VAE model for normalization
auto_tiling (bool): Whether to automatically tile the input for processing
Returns:
tuple: Tuple containing the upsampled latent
"""
device = model_management.get_torch_device()
memory_required = model_management.module_size(upscale_model)
model_dtype = next(upscale_model.parameters()).dtype
latents = samples["samples"]
input_dtype = latents.dtype
memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate
model_management.free_memory(memory_required, device)
try:
upscale_model.to(device) # TODO: use the comfy model management system.
latents = latents.to(dtype=model_dtype, device=device)
"""Upsample latents without tiling."""
latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
upsampled_latents = upscale_model(latents)
finally:
upscale_model.cpu()
upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
upsampled_latents
)
upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device())
return_dict = samples.copy()
return_dict["samples"] = upsampled_latents
return_dict.pop("noise_mask", None)
return (return_dict,)
NODE_CLASS_MAPPINGS = {
"LTXVLatentUpsampler": LTXVLatentUpsampler,
}

View File

@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Mahiro",
display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
display_name="Mahiro CFG",
category="_for_testing",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[

View File

@ -244,6 +244,10 @@ class ModelPatchLoader:
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
sd = z_image_convert(sd)
config = {}
if 'control_layers.4.adaLN_modulation.0.weight' not in sd:
config['n_control_layers'] = 3
config['additional_in_dim'] = 17
config['refiner_control'] = True
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
config['n_control_layers'] = 15
config['additional_in_dim'] = 17

View File

@ -78,18 +78,20 @@ class ImageUpscaleWithModel(io.ComfyNode):
overlap = 32
oom = True
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
raise e
try:
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
tile //= 2
if tile < 128:
raise e
finally:
upscale_model.to("cpu")
upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
return io.NodeOutput(s)

View File

@ -817,7 +817,7 @@ def get_sample_indices(original_fps,
if required_duration > total_frames / original_fps:
raise ValueError("required_duration must be less than video length")
if not fixed_start is None and fixed_start >= 0:
if fixed_start is not None and fixed_start >= 0:
start_frame = fixed_start
else:
max_start = total_frames - required_origin_frames

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.7.0"
__version__ = "0.9.1"

View File

@ -601,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()

Some files were not shown because too many files have changed in this diff Show More