Compare commits

...

48 Commits

Author SHA1 Message Date
comfyanonymous
5ac1372533 ComfyUI v0.9.1
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-13 01:44:06 -05:00
comfyanonymous
1dcbd9efaf
Bump ltxav mem estimation a bit. (#11842) 2026-01-13 01:42:07 -05:00
comfyanonymous
db9e6edfa1 ComfyUI v0.9.0 2026-01-13 01:23:31 -05:00
Christian Byrne
8af13b439b
Update requirements.txt (#11841) 2026-01-13 01:22:25 -05:00
Jedrzej Kosinski
acd0e53653
Make bulk_ops not use .returning to be compatible with python 3.10 and 3.11 sqlalchemy (#11839) 2026-01-13 00:15:24 -05:00
comfyanonymous
117e7a5853
Refactor to try to lower mem usage. (#11840) 2026-01-12 21:01:52 -08:00
comfyanonymous
b3c0e4de57
Make loras work on nvfp4 models. (#11837)
The initial applying is a bit slow but will probably be sped up in the
future.
2026-01-12 22:33:54 -05:00
ComfyUI Wiki
ecaeeb990d
chore: update workflow templates to v0.8.4 (#11835) 2026-01-12 19:18:01 -08:00
ComfyUI Wiki
c2b65e2fce
Update workflow templates to v0.8.0 (#11828)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-12 17:29:25 -05:00
Jukka Seppänen
fd5c0755af
Reduce LTX2 VRAM use by more efficient timestep embed handling (#11829) 2026-01-12 17:28:59 -05:00
comfyanonymous
c881a1d689
Support the siglip 2 naflex model as a clip vision model. (#11831)
Not useful yet.
2026-01-12 17:05:54 -05:00
kelseyee
a3b5d4996a
Support ModelScope-Trainer DiffSynth lora for Z Image. (#11805) 2026-01-12 15:38:46 -05:00
comfyanonymous
c6238047ee
Put more details about portable in readme. (#11816)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-11 21:11:53 -05:00
Alexander Piskun
5cd1113236
fix(api-nodes): use a unique name for uploading audio files (#11778)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
2026-01-11 03:07:11 -08:00
comfyanonymous
2f642d5d9b
Fix chroma fp8 te being treated as fp16. (#11795)
Some checks are pending
Python Linting / Run Pylint (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
2026-01-10 14:40:42 -08:00
comfyanonymous
cd912963f1
Fix issue with t5 text encoder in fp4. (#11794) 2026-01-10 17:31:31 -05:00
DELUXA
6e4b1f9d00
pythorch_attn_by_def_on_gfx1200 (#11793) 2026-01-10 16:51:05 -05:00
comfyanonymous
dc202a2e51
Properly save mixed ops. (#11772)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-01-10 02:03:57 -05:00
ComfyUI Wiki
153bc524bf
chore: update embedded docs to v0.4.0 (#11776) 2026-01-10 01:29:30 -05:00
Alexander Piskun
393d2880dd
feat(api-nodes): added nodes for Vidu2 (#11760)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
2026-01-09 12:59:38 -08:00
Alexander Piskun
4484b93d61
fix(api-nodes): do not downscale the input image for Topaz Enhance (#11768) 2026-01-09 12:25:56 -08:00
comfyanonymous
bd0e6825e8
Be less strict when loading mixed ops weights. (#11769) 2026-01-09 14:21:06 -05:00
Jedrzej Kosinski
ec0a832acb
Add workaround for hacky nodepack(s) that edit folder_names_and_paths to have values with tuples of more than 2. Other things could potentially break with those nodepack(s), so I will hunt for the guilty nodepack(s) now. (#11755)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-08 22:49:12 -08:00
ric-yu
04c49a29b4
feat: add cancelled filter to /jobs (#11680) 2026-01-08 21:57:36 -08:00
Terry Jia
4609fcd260
add node - image compare (#11343)
Some checks failed
Python Linting / Run Pylint (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-01-08 21:31:19 -08:00
rattus
6207f86c18
Fix VAEEncodeForInpaint to support WAN VAE tuple downscale_ratio (#11572)
Use vae.spacial_compression_encode() instead of directly accessing
downscale_ratio to handle both standard VAEs (int) and WAN VAEs (tuple).

Addresses reviewer feedback on PR #11259.

Co-authored-by: ChrisFab16 <christopher@fabritius.dk>
2026-01-08 23:34:48 -05:00
Jedrzej Kosinski
1dc3da6314
Add most basic Asset support for models (#11315)
* Brought over minimal elements from PR 10045 to reproduce seed_assets and register_assets_system without adding anything to the DB or server routes yet, for now making everything sync (can introduce async once everything is cleaned up and brought over)

* Added db script to insert assets stuff, cleaned up some code; assets (models) now get added/rescanned

* Added support for 5 http endpoints for assets

* Replaced Optional with | None in schemas_in.py and schemas_out.py

* Remove two routes that will not be relevant yet in this PR: HEAD /api/assets/hash/<hash> and PUT /api/assets/<id>/preview

* Remove some functions the two deleted endpoints were using

* Don't show assets scan message upon calling /object_info endpoint

* removed unsued import to satisfy ruff

* Simplified hashing function tpye hint and _hash_file_obj

* Satisfied ruff
2026-01-08 22:21:51 -05:00
Comfy Org PR Bot
114fc73685
Bump comfyui-frontend-package to 1.36.13 (#11645) 2026-01-08 22:16:15 -05:00
comfyanonymous
b48d6a83d4
Fix csp error in frontend when forcing offline. (#11749) 2026-01-08 22:15:50 -05:00
Jukka Seppänen
027042db68
Add node: JoinAudioChannels (#11728) 2026-01-08 22:14:06 -05:00
comfyanonymous
1a20656448
Fix import issue. (#11746) 2026-01-08 17:23:59 -05:00
comfyanonymous
0f11869d55
Better detection if AMD torch compiled with efficient attention. (#11745) 2026-01-08 17:16:58 -05:00
Dr.Lt.Data
5943fbf457
bump comfyui_manager version to the 4.0.5 (#11732)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-08 08:15:42 -08:00
Yoland Yan
a60b7b86c5
Revert "Force sequential execution in CI test jobs (#11687)" (#11725)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This reverts commit ce0000c4f2.
2026-01-07 21:41:57 -08:00
comfyanonymous
2e9d51680a ComfyUI version v0.8.2
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-07 23:50:02 -05:00
comfyanonymous
50d6e1caf4
Tweak ltxv vae mem estimation. (#11722) 2026-01-07 23:07:05 -05:00
comfyanonymous
ac12f77bed ComfyUI version v0.8.1 2026-01-07 22:10:08 -05:00
ComfyUI Wiki
fcd9a236b0
Update template to 0.7.69 (#11719) 2026-01-07 18:22:23 -08:00
comfyanonymous
21e8425087
Add warning for old pytorch. (#11718) 2026-01-07 21:07:26 -05:00
rattus
b6c79a648a
ops: Fix offloading with FP8MM performance (#11697)
This logic was checking comfy_cast_weights, and going straight to
to the forward_comfy_cast_weights implementation without
attempting to downscale input to fp8 in the event comfy_cast_weights
is set.

The main reason comfy_cast_weights would be set would be for async
offload, which is not a good reason to nix FP8MM.

So instead, and together the underlying exclusions for FP8MM which
are:

* having a weight_function (usually LowVramPatch)
* force_cast_weights (compute dtype override)
* the weight is not Quantized
* the input is already quantized
* the model or layer has MM explictily disabled.

If you get past all of those exclusions, quantize the input tensor.
Then hand the new input, quantized or not off to
forward_comfy_cast_weights to handle it. If the weight is offloaded
but input is quantized you will get an offloaded MM8.
2026-01-07 21:01:16 -05:00
comfyanonymous
25bc1b5b57
Add memory estimation function to ltxav text encoder. (#11716) 2026-01-07 20:11:22 -05:00
comfyanonymous
3cd19e99c1
Increase ltxav mem estimation by a bit. (#11715) 2026-01-07 20:04:56 -05:00
comfyanonymous
007b87e7ac
Bump required comfy-kitchen version. (#11714) 2026-01-07 19:48:47 -05:00
comfyanonymous
34751fe9f9
Lower ltxv text encoder vram use. (#11713) 2026-01-07 19:12:15 -05:00
Jukka Seppänen
1c705f7bfb
Add device selection for LTXAVTextEncoderLoader (#11700) 2026-01-07 18:39:59 -05:00
rattus
48e5ea1dfd
model_patcher: Remove confusing load stat (#11710)
If the loader passes 1e32 as the usable memory size, it means force
the full load. This happens with CPU loads and a few other misc cases.
Removing the confusing number and just leave the other details.
2026-01-07 18:39:20 -05:00
comfyanonymous
3cd7b32f1b
Support gemma 12B with quant weights. (#11696)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-01-07 05:15:14 -05:00
comfyanonymous
c0c9720d77
Fix stable release workflow not pulling latest comfy kitchen. (#11695) 2026-01-07 04:48:28 -05:00
55 changed files with 3005 additions and 298 deletions

View File

@ -117,7 +117,7 @@ jobs:
./python.exe get-pip.py ./python.exe get-pip.py
./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/*
grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt
./python.exe -s -m pip install -r requirements_comfyui.txt ./python.exe -s -m pip install -r requirements_comfyui.txt
rm requirements_comfyui.txt rm requirements_comfyui.txt

View File

@ -20,7 +20,6 @@ jobs:
test-stable: test-stable:
strategy: strategy:
fail-fast: false fail-fast: false
max-parallel: 1 # This forces sequential execution
matrix: matrix:
# os: [macos, linux, windows] # os: [macos, linux, windows]
# os: [macos, linux] # os: [macos, linux]
@ -75,7 +74,6 @@ jobs:
test-unix-nightly: test-unix-nightly:
strategy: strategy:
fail-fast: false fail-fast: false
max-parallel: 1 # This forces sequential execution
matrix: matrix:
# os: [macos, linux] # os: [macos, linux]
os: [linux] os: [linux]

View File

@ -183,7 +183,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
If you have trouble extracting it, right click the file -> properties -> unblock If you have trouble extracting it, right click the file -> properties -> unblock
Update your Nvidia drivers if it doesn't start. The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads: #### Alternative Downloads:
@ -212,7 +212,7 @@ Python 3.14 works but you may encounter issues with the torch compile node. The
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old.
### Instructions: ### Instructions:

View File

@ -0,0 +1,174 @@
"""
Initial assets schema
Revision ID: 0001_assets
Revises: None
Create Date: 2025-12-10 00:00:00
"""
from alembic import op
import sqlalchemy as sa
revision = "0001_assets"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
op.create_index("ix_tags_tag_type", "tags", ["tag_type"])
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
sa.column("tag_type", sa.String()),
)
op.bulk_insert(
tags_table,
[
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
{"name": "vae", "tag_type": "system"},
{"name": "text_encoders", "tag_type": "system"},
{"name": "diffusion_models", "tag_type": "system"},
{"name": "clip_vision", "tag_type": "system"},
{"name": "style_models", "tag_type": "system"},
{"name": "embeddings", "tag_type": "system"},
{"name": "diffusers", "tag_type": "system"},
{"name": "vae_approx", "tag_type": "system"},
{"name": "controlnet", "tag_type": "system"},
{"name": "gligen", "tag_type": "system"},
{"name": "upscale_models", "tag_type": "system"},
{"name": "hypernetworks", "tag_type": "system"},
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

102
app/assets/api/routes.py Normal file
View File

@ -0,0 +1,102 @@
import logging
import uuid
from aiohttp import web
from pydantic import ValidationError
import app.assets.manager as manager
from app import user_manager
from app.assets.api import schemas_in
from app.assets.helpers import get_query_dict
ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None
# UUID regex (canonical hyphenated form, case-insensitive)
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global USER_MANAGER
USER_MANAGER = user_manager_instance
app.add_routes(ROUTES)
def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response:
return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status)
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
"""
GET request to list assets.
"""
query_dict = get_query_dict(request)
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)
except ValidationError as ve:
return _validation_error_response("INVALID_QUERY", ve)
payload = manager.list_assets(
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
async def get_asset(request: web.Request) -> web.Response:
"""
GET request to get an asset's info as JSON.
"""
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
result = manager.get_asset(
asset_info_id=asset_info_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except ValueError as e:
return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id})
except Exception:
logging.exception(
"get_asset failed for asset_info_id=%s, owner_id=%s",
asset_info_id,
USER_MANAGER.get_request_user_id(request),
)
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response:
"""
GET request to list all tags based on query parameters.
"""
query_map = dict(request.rel_url.query)
try:
query = schemas_in.TagsListQuery.model_validate(query_map)
except ValidationError as e:
return web.json_response(
{"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}},
status=400,
)
result = manager.list_tags(
prefix=query.prefix,
limit=query.limit,
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=USER_MANAGER.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))

View File

@ -0,0 +1,94 @@
import json
import uuid
from typing import Any, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
)
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
# Accept either a JSON string (query param) or a dict
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
# Accept "a,b,c" or ["a","b"] (we are liberal in what we accept)
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
prefix: str | None = Field(None, min_length=1, max_length=256)
limit: int = Field(100, ge=1, le=1000)
offset: int = Field(0, ge=0, le=10_000_000)
order: Literal["count_desc", "name_asc"] = "count_desc"
include_zero: bool = True
@field_validator("prefix")
@classmethod
def normalize_prefix(cls, v: str | None) -> str | None:
if v is None:
return v
v = v.strip()
return v.lower() or None
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: str | None = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s

View File

@ -0,0 +1,60 @@
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
preview_url: str | None = None
created_at: datetime | None = None
updated_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetsList(BaseModel):
assets: list[AssetSummary]
total: int
has_more: bool
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):
tags: list[TagUsage] = Field(default_factory=list)
total: int
has_more: bool

View File

@ -0,0 +1,204 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@ -0,0 +1,233 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
Integer,
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
infos: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Asset | None] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="asset_info",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
)
tags: Mapped[list[Tag]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
)
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
)
asset_infos: Mapped[list[AssetInfo]] = relationship(
secondary="asset_info_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@ -0,0 +1,267 @@
import sqlalchemy as sa
from collections import defaultdict
from sqlalchemy import select, exists, func
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_like_prefix, normalize_tags
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)

View File

@ -0,0 +1,62 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

75
app/assets/hashing.py Normal file
View File

@ -0,0 +1,75 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

217
app/assets/helpers.py Normal file
View File

@ -0,0 +1,217 @@
import contextlib
import os
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Any
import folder_paths
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
"""
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
Normalize a list of tags by:
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out

123
app/assets/manager.py Normal file
View File

@ -0,0 +1,123 @@
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out
from app.assets.database.queries import (
asset_exists_by_hash,
fetch_asset_info_asset_and_tags,
list_asset_infos_page,
list_tags_with_usage,
)
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

229
app/assets/scanner.py Normal file
View File

@ -0,0 +1,229 @@
import contextlib
import time
import logging
import os
import sqlalchemy
import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
RootType
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""
Scan the given roots and seed the assets into the database.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
len(paths),
)
def _fast_db_consistency_pass(
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None

View File

@ -1,14 +1,21 @@
from sqlalchemy.orm import declarative_base from typing import Any
from datetime import datetime
from sqlalchemy.orm import DeclarativeBase
Base = declarative_base() class Base(DeclarativeBase):
pass
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
def to_dict(obj):
fields = obj.__table__.columns.keys() fields = obj.__table__.columns.keys()
return { out: dict[str, Any] = {}
field: (val.to_dict() if hasattr(val, "to_dict") else val) for field in fields:
for field in fields val = getattr(obj, field)
if (val := getattr(obj, field)) if val is None and not include_none:
} continue
if isinstance(val, datetime):
out[field] = val.isoformat()
else:
out[field] = val
return out
# TODO: Define models here # TODO: Define models here

View File

@ -231,6 +231,7 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
) )
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()

View File

@ -1,6 +1,7 @@
import torch import torch
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops import comfy.ops
import math
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
image = image[:, :, :, :3] if image.shape[3] > 3 else image image = image[:, :, :, :3] if image.shape[3] > 3 else image
@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0 image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1]) return (image - mean.view([3,1,1])) / std.view([3,1,1])
def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5):
def scale_dim(size, scale):
scaled = math.ceil(size * scale / patch_size) * patch_size
return max(patch_size, int(scaled))
# Binary search for optimal scale
lo, hi = eps / 10, 100.0
while hi - lo >= eps:
mid = (lo + hi) / 2
h, w = scale_dim(oh, mid), scale_dim(ow, mid)
if (h // patch_size) * (w // patch_size) <= max_num_patches:
lo = mid
else:
hi = mid
return scale_dim(oh, lo), scale_dim(ow, lo)
def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True):
if size > 0:
return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop)
image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
b, c, h, w = image.shape
h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches)
image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True)
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1])
class CLIPAttention(torch.nn.Module): class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations): def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__() super().__init__()
@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module):
out = self.text_projection(x[2]) out = self.text_projection(x[2])
return (x[0], x[1], out, x[2]) return (x[0], x[1], out, x[2])
def siglip2_pos_embed(embed_weight, embeds, orig_shape):
embed_weight_len = round(embed_weight.shape[0] ** 0.5)
embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len)
embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True)
embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1)
return embeds + embed_weight
class Siglip2Embeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None):
super().__init__()
self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
self.patch_size = patch_size
def forward(self, pixel_values):
b, c, h, w = pixel_values.shape
img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c)
img = img.permute(0, 1, 3, 2, 4, 5)
img = img.reshape(b, img.shape[1] * img.shape[2], -1)
img = self.patch_embedding(img)
return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size))
class CLIPVisionEmbeddings(torch.nn.Module): class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module):
intermediate_activation = config_dict["hidden_act"] intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"] model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) if model_type in ["siglip2_vision_model"]:
if model_type == "siglip_vision_model": self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations)
else:
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type in ["siglip_vision_model", "siglip2_vision_model"]:
self.pre_layrnorm = lambda a: a self.pre_layrnorm = lambda a: a
self.output_layernorm = True self.output_layernorm = True
else: else:

View File

@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br
IMAGE_ENCODERS = { IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model, "dinov2": comfy.image_encoders.dino2.Dinov2Model,
} }
@ -32,9 +33,10 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224) self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_type = config.get("model_type", "clip_vision_model") self.model_type = config.get("model_type", "clip_vision_model")
model_class = IMAGE_ENCODERS.get(model_type) self.config = config.copy()
if model_type == "siglip_vision_model": model_class = IMAGE_ENCODERS.get(self.model_type)
if self.model_type == "siglip_vision_model":
self.return_all_hidden_states = True self.return_all_hidden_states = True
else: else:
self.return_all_hidden_states = False self.return_all_hidden_states = False
@ -55,7 +57,10 @@ class ClipVisionModel():
def encode_image(self, image, crop=True): def encode_image(self, image, crop=True):
comfy.model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() if self.model_type == "siglip2_vision_model":
pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float()
else:
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)
outputs = Output() outputs = Output()
@ -107,10 +112,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
if embed_shape == 729: patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") if len(patch_embedding_shape) == 2:
elif embed_shape == 1024: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json")
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") else:
if embed_shape == 729:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif embed_shape == 1024:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
elif embed_shape == 577: elif embed_shape == 577:
if "multi_modal_projector.linear_1.bias" in sd: if "multi_modal_projector.linear_1.bias" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")

View File

@ -0,0 +1,14 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": -1,
"intermediate_size": 4304,
"model_type": "siglip2_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 16,
"num_patches": 256,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@ -65,3 +65,121 @@ def stochastic_rounding(value, dtype, seed=0):
return output return output
return value.to(dtype=dtype) return value.to(dtype=dtype)
# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
orig_shape = x.shape
sign = torch.signbit(x).to(torch.uint8)
exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3)
x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25
x = x.abs()
exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3)
mantissa = torch.where(
exp > 0,
(x / (2.0 ** (exp - 1)) - 1.0) * 2.0,
(x * 2.0),
out=x
).round().to(torch.uint8)
del x
exp = exp.to(torch.uint8)
fp4 = (sign << 3) | (exp << 1) | mantissa
del sign, exp, mantissa
fp4_flat = fp4.view(-1)
packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
return packed.reshape(list(orig_shape)[:-1] + [-1])
def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
"""
Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
See:
https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
Args:
input_matrix: Input tensor of shape (H, W)
Returns:
Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4))
"""
def ceil_div(a, b):
return (a + b - 1) // b
rows, cols = input_matrix.shape
n_row_blocks = ceil_div(rows, 128)
n_col_blocks = ceil_div(cols, 4)
# Calculate the padded shape
padded_rows = n_row_blocks * 128
padded_cols = n_col_blocks * 4
padded = input_matrix
if (rows, cols) != (padded_rows, padded_cols):
padded = torch.zeros(
(padded_rows, padded_cols),
device=input_matrix.device,
dtype=input_matrix.dtype,
)
padded[:rows, :cols] = input_matrix
# Rearrange the blocks
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
if flatten:
return rearranged.flatten()
return rearranged.reshape(padded_rows, padded_cols)
def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
F4_E2M1_MAX = 6.0
F8_E4M3_MAX = 448.0
def roundup(x: int, multiple: int) -> int:
"""Round up x to the nearest multiple."""
return ((x + multiple - 1) // multiple) * multiple
orig_shape = x.shape
# Handle padding
if pad_16x:
rows, cols = x.shape
padded_rows = roundup(rows, 16)
padded_cols = roundup(cols, 16)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
# Note: We update orig_shape because the output tensor logic below assumes x.shape matches
# what we want to produce. If we pad here, we want the padded output.
orig_shape = x.shape
block_size = 16
x = x.reshape(orig_shape[0], -1, block_size)
max_abs = torch.amax(torch.abs(x), dim=-1)
block_scale = max_abs / F4_E2M1_MAX
scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)
# Handle zero blocks (from padding): avoid 0/0 NaN
zero_scale_mask = (total_scale == 0)
total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
x = x / total_scale_safe.unsqueeze(-1)
generator = torch.Generator(device=x.device)
generator.manual_seed(seed)
x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)
x = x.view(orig_shape)
data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)
blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
return data_lp, blocked_scales

View File

@ -3,8 +3,8 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management import comfy.model_management
import model_patcher import comfy.model_patcher
class SRResidualCausalBlock3D(nn.Module): class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int): def __init__(self, channels: int):
@ -103,13 +103,13 @@ UPSAMPLERS = {
class HunyuanVideo15SRModel(): class HunyuanVideo15SRModel():
def __init__(self, model_type, config): def __init__(self, model_type, config):
self.load_device = model_management.vae_device() self.load_device = comfy.model_management.vae_device()
offload_device = model_management.vae_offload_device() offload_device = comfy.model_management.vae_offload_device()
self.dtype = model_management.vae_dtype(self.load_device) self.dtype = comfy.model_management.vae_dtype(self.load_device)
self.model_class = UPSAMPLERS.get(model_type) self.model_class = UPSAMPLERS.get(model_type)
self.model = self.model_class(**config).eval() self.model = self.model_class(**config).eval()
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd): def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=True) return self.model.load_state_dict(sd, strict=True)
@ -118,5 +118,5 @@ class HunyuanVideo15SRModel():
return self.model.state_dict() return self.model.state_dict()
def resample_latent(self, latent): def resample_latent(self, latent):
model_management.load_model_gpu(self.patcher) comfy.model_management.load_model_gpu(self.patcher)
return self.model(latent.to(self.load_device)) return self.model(latent.to(self.load_device))

View File

@ -11,6 +11,69 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
import comfy.ldm.common_dit import comfy.ldm.common_dit
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
"""
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
patches_per_frame: Number of spatial patches per frame (height * width in latent space)
"""
self.batch_size, num_tokens, self.feature_dim = tensor.shape
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
# All patches in a frame are identical, so we only keep the first one
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1
self.num_frames = num_tokens
self.data = tensor
def expand(self):
"""Expand back to original tensor."""
if self.patches_per_frame == 1:
return self.data
# [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim]
expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim)
return expanded.reshape(self.batch_size, -1, self.feature_dim)
def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)):
"""Compute ada values on compressed per-frame data, then expand spatially."""
num_ada_params = scale_shift_table.shape[0]
# No compression - compute directly
if self.patches_per_frame == 1:
num_tokens = self.data.shape[1]
dim_per_param = self.feature_dim // num_ada_params
reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype)
ada_values = (table_values + reshaped).unbind(dim=2)
return ada_values
# Compressed: compute on per-frame data then expand spatially
# Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param]
frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :]
table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(
device=self.data.device, dtype=self.data.dtype
)
frame_ada = (table_values + frame_reshaped).unbind(dim=2)
# Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim]
return tuple(
frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1)
.reshape(batch_size, -1, frame_val.shape[-1])
for frame_val in frame_ada
)
class BasicAVTransformerBlock(nn.Module): class BasicAVTransformerBlock(nn.Module):
def __init__( def __init__(
self, self,
@ -119,6 +182,9 @@ class BasicAVTransformerBlock(nn.Module):
def get_ada_values( def get_ada_values(
self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None)
): ):
if isinstance(timestep, CompressedTimestep):
return timestep.expand_for_computation(scale_shift_table, batch_size, indices)
num_ada_params = scale_shift_table.shape[0] num_ada_params = scale_shift_table.shape[0]
ada_values = ( ada_values = (
@ -146,10 +212,7 @@ class BasicAVTransformerBlock(nn.Module):
gate_timestep, gate_timestep,
) )
scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] return (*scale_shift_ada_values, *gate_ada_values)
gate_ada_values = [t.squeeze(2) for t in gate_ada_values]
return (*scale_shift_chunks, *gate_ada_values)
def forward( def forward(
self, self,
@ -543,72 +606,80 @@ class LTXAVModel(LTXVModel):
if grid_mask is not None: if grid_mask is not None:
timestep = timestep[:, grid_mask] timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single( v_timestep, v_embedded_timestep = self.adaln_single(
timestep.flatten(), timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None}, {"resolution": None, "aspect_ratio": None},
batch_size=batch_size, batch_size=batch_size,
hidden_dtype=hidden_dtype, hidden_dtype=hidden_dtype,
) )
# Second dimension is 1 or number of tokens (if timestep_per_token) # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
v_embedded_timestep = v_embedded_timestep.view( orig_shape = kwargs.get("orig_shape")
batch_size, -1, v_embedded_timestep.shape[-1] v_patches_per_frame = None
) if orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
# Prepare audio timestep # Prepare audio timestep
a_timestep = kwargs.get("a_timestep") a_timestep = kwargs.get("a_timestep")
if a_timestep is not None: if a_timestep is not None:
a_timestep = a_timestep * self.timestep_scale_multiplier a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten()
timestep_flat = timestep_scaled.flatten()
av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep.flatten(), a_timestep_flat,
{"resolution": None, "aspect_ratio": None}, {"resolution": None, "aspect_ratio": None},
batch_size=batch_size, batch_size=batch_size,
hidden_dtype=hidden_dtype, hidden_dtype=hidden_dtype,
) )
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep.flatten(), timestep_flat,
{"resolution": None, "aspect_ratio": None}, {"resolution": None, "aspect_ratio": None},
batch_size=batch_size, batch_size=batch_size,
hidden_dtype=hidden_dtype, hidden_dtype=hidden_dtype,
) )
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep.flatten() * av_ca_factor, timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None}, {"resolution": None, "aspect_ratio": None},
batch_size=batch_size, batch_size=batch_size,
hidden_dtype=hidden_dtype, hidden_dtype=hidden_dtype,
) )
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep.flatten() * av_ca_factor, a_timestep_flat * av_ca_factor,
{"resolution": None, "aspect_ratio": None}, {"resolution": None, "aspect_ratio": None},
batch_size=batch_size, batch_size=batch_size,
hidden_dtype=hidden_dtype, hidden_dtype=hidden_dtype,
) )
# Compress cross-attention timesteps (only video side, audio is too small to benefit)
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]),
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed
CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed
av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]),
]
a_timestep, a_embedded_timestep = self.audio_adaln_single( a_timestep, a_embedded_timestep = self.audio_adaln_single(
a_timestep.flatten(), a_timestep_flat,
{"resolution": None, "aspect_ratio": None}, {"resolution": None, "aspect_ratio": None},
batch_size=batch_size, batch_size=batch_size,
hidden_dtype=hidden_dtype, hidden_dtype=hidden_dtype,
) )
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view( a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
batch_size, -1, a_embedded_timestep.shape[-1]
)
cross_av_timestep_ss = [
av_ca_audio_scale_shift_timestep,
av_ca_video_scale_shift_timestep,
av_ca_a2v_gate_noise_timestep,
av_ca_v2a_gate_noise_timestep,
]
cross_av_timestep_ss = list(
[t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss]
)
else: else:
a_timestep = timestep a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep") a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = [] cross_av_timestep_ss = []
@ -767,6 +838,11 @@ class LTXAVModel(LTXVModel):
ax = x[1] ax = x[1]
v_embedded_timestep = embedded_timestep[0] v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1] a_embedded_timestep = embedded_timestep[1]
# Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand()
vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs)
# Process audio output # Process audio output

View File

@ -322,6 +322,7 @@ def model_lora_keys_unet(model, key_map={}):
key_map["diffusion_model.{}".format(key_lora)] = to key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
key_map[key_lora] = to
if isinstance(model, comfy.model_base.Kandinsky5): if isinstance(model, comfy.model_base.Kandinsky5):
for k in sdk: for k in sdk:

View File

@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else: else:
dit_config["vec_in_dim"] = None dit_config["vec_in_dim"] = None
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma

View File

@ -22,7 +22,6 @@ from enum import Enum
from comfy.cli_args import args, PerformanceFeature from comfy.cli_args import args, PerformanceFeature
import torch import torch
import sys import sys
import importlib
import platform import platform
import weakref import weakref
import gc import gc
@ -349,15 +348,27 @@ try:
except: except:
rocm_version = (6, -1) rocm_version = (6, -1)
def aotriton_supported(gpu_arch):
path = torch.__path__[0]
path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
if gpu_arch in gfx:
return True
if "{}x".format(gpu_arch[:-1]) in gfx:
return True
if "{}xx".format(gpu_arch[:-2]) in gfx:
return True
return False
logging.info("AMD arch: {}".format(arch)) logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version)) logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0): if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]): if any((a in arch) for a in ["gfx1200", "gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0 if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0

View File

@ -718,6 +718,7 @@ class ModelPatcher:
continue continue
cast_weight = self.force_cast_weights cast_weight = self.force_cast_weights
m.comfy_force_cast_weights = self.force_cast_weights
if lowvram_weight: if lowvram_weight:
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
m.weight_function = [] m.weight_function = []
@ -790,11 +791,12 @@ class ModelPatcher:
for param in params: for param in params:
self.pin_weight_to_device("{}.{}".format(n, param)) self.pin_weight_to_device("{}.{}".format(n, param))
usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else ""
if lowvram_counter > 0: if lowvram_counter > 0:
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
self.model.model_lowvram = True self.model.model_lowvram = True
else: else:
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False self.model.model_lowvram = False
if full_load: if full_load:
self.model.to(device_to) self.model.to(device_to)

View File

@ -546,7 +546,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
weight_key = f"{prefix}weight" weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None) weight = state_dict.pop(weight_key, None)
if weight is None: if weight is None:
raise ValueError(f"Missing weight for layer {layer_name}") logging.warning(f"Missing weight for layer {layer_name}")
return
manually_loaded_keys = [weight_key] manually_loaded_keys = [weight_key]
@ -624,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(key) missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs): def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) if destination is not None:
if isinstance(self.weight, QuantizedTensor): sd = destination
layout_cls = self.weight._layout_cls else:
sd = {}
# Check if it's any FP8 variant (E4M3 or E5M2) if self.bias is not None:
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"): sd["{}bias".format(prefix)] = self.bias
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
elif layout_cls == "TensorCoreNVFP4Layout": if isinstance(self.weight, QuantizedTensor):
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale sd_out = self.weight.state_dict("{}weight".format(prefix))
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format} quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config: if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd return sd
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):
@ -654,29 +663,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
run_every_op() run_every_op()
input_shape = input.shape input_shape = input.shape
tensor_3d = input.ndim == 3 reshaped_3d = False
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(input, *args, **kwargs)
if (getattr(self, 'layout_type', None) is not None and if (getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor)): not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0):
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
if tensor_3d: input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
input = input.reshape(-1, input_shape[2])
if input.ndim != 2: # Fall back to non-quantized for non-2D tensors
# Fall back to comfy_cast_weights for non-2D tensors if input_reshaped.ndim == 2:
return self.forward_comfy_cast_weights(input.reshape(input_shape), *args, **kwargs) reshaped_3d = input.ndim == 3
# dtype is now implicit in the layout class
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
# dtype is now implicit in the layout class output = self.forward_comfy_cast_weights(input)
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None))
output = self._forward(input, self.weight, self.bias)
# Reshape output back to 3D if input was 3D # Reshape output back to 3D if input was 3D
if tensor_3d: if reshaped_3d:
output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0]))
return output return output
@ -690,7 +699,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
if getattr(self, 'layout_type', None) is not None: if getattr(self, 'layout_type', None) is not None:
# dtype is now implicit in the layout class # dtype is now implicit in the layout class
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True) weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype)
else: else:
weight = weight.to(self.weight.dtype) weight = weight.to(self.weight.dtype)
if return_weight: if return_weight:

View File

@ -7,7 +7,7 @@ try:
QuantizedTensor, QuantizedTensor,
QuantizedLayout, QuantizedLayout,
TensorCoreFP8Layout as _CKFp8Layout, TensorCoreFP8Layout as _CKFp8Layout,
TensorCoreNVFP4Layout, # Direct import, no wrapper needed TensorCoreNVFP4Layout as _CKNvfp4Layout,
register_layout_op, register_layout_op,
register_layout_class, register_layout_class,
get_layout_class, get_layout_class,
@ -19,6 +19,7 @@ try:
cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) cuda_version = tuple(map(int, str(torch.version.cuda).split('.')))
if cuda_version < (13,): if cuda_version < (13,):
ck.registry.disable("cuda") ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton") ck.registry.disable("triton")
for k, v in ck.list_backends().items(): for k, v in ck.list_backends().items():
@ -33,7 +34,7 @@ except ImportError as e:
class _CKFp8Layout: class _CKFp8Layout:
pass pass
class TensorCoreNVFP4Layout: class _CKNvfp4Layout:
pass pass
def register_layout_class(name, cls): def register_layout_class(name, cls):
@ -83,6 +84,39 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
if scale is None or (isinstance(scale, str) and scale == "recalculate"):
scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX)
if not isinstance(scale, torch.Tensor):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding)
params = cls.Params(
scale=scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
block_scale=block_scale,
)
return qdata, params
class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase):
FP8_DTYPE = torch.float8_e4m3fn FP8_DTYPE = torch.float8_e4m3fn

View File

@ -218,7 +218,7 @@ class CLIP:
if unprojected: if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model() self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset() all_hooks.reset()
self.patcher.patch_hooks(None) self.patcher.patch_hooks(None)
@ -266,7 +266,7 @@ class CLIP:
if return_pooled == "unprojected": if return_pooled == "unprojected":
self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model() self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens) o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2] cond, pooled = o[:2]
@ -299,8 +299,11 @@ class CLIP:
sd_clip[k] = sd_tokenizer[k] sd_clip[k] = sd_tokenizer[k]
return sd_clip return sd_clip
def load_model(self): def load_model(self, tokens={}):
model_management.load_model_gpu(self.patcher) memory_used = 0
if hasattr(self.cond_stage_model, "memory_estimation_function"):
memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device)
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
return self.patcher return self.patcher
def get_key_patches(self): def get_key_patches(self):
@ -476,8 +479,8 @@ class VAE:
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128 self.latent_channels = 128
self.latent_dim = 3 self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.upscale_index_formula = (8, 32, 32) self.upscale_index_formula = (8, 32, 32)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
@ -1056,9 +1059,9 @@ def detect_te_model(sd):
return TEModel.JINA_CLIP_2 return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096: if weight.shape[0] == 10240:
return TEModel.T5_XXL return TEModel.T5_XXL
elif weight.shape[-1] == 2048: elif weight.shape[0] == 5120:
return TEModel.T5_XL return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD return TEModel.T5_XXL_OLD

View File

@ -845,7 +845,7 @@ class LTXAV(LTXV):
def __init__(self, unet_config): def __init__(self, unet_config):
super().__init__(unet_config) super().__init__(unet_config)
self.memory_usage_factor = 0.055 # TODO self.memory_usage_factor = 0.077 # TODO
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXAV(self, device=device) out = model_base.LTXAV(self, device=device)

View File

@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None: if t5_quantization_metadata is not None:
model_options = model_options.copy() model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None: if dtype_t5 is not None:
dtype = dtype_t5 dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options) super().__init__(device=device, dtype=dtype, model_options=model_options)
return CosmosTEModel_ return CosmosTEModel_

View File

@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None: if t5_quantization_metadata is not None:
model_options = model_options.copy() model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None: if dtype_t5 is not None:
dtype = dtype_t5 dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options) super().__init__(device=device, dtype=dtype, model_options=model_options)
return MochiTEModel_ return MochiTEModel_

View File

@ -36,10 +36,10 @@ class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
class Gemma3_12BModel(sd1_clip.SDClipModel): class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_scaled_fp8 is not None: if llama_quantization_metadata is not None:
model_options = model_options.copy() model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8 model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@ -98,10 +98,13 @@ class LTXAVTEModel(torch.nn.Module):
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
out_device = out.device out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device) out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1)) out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out) out = self.text_embedding_projection(out)
out = out.float()
out_vid = self.video_embeddings_connector(out)[0] out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0] out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1) out = torch.concat((out_vid, out_audio), dim=-1)
@ -118,13 +121,21 @@ class LTXAVTEModel(torch.nn.Module):
return self.load_state_dict(sdo, strict=False) return self.load_state_dict(sdo, strict=False)
def memory_estimation_function(self, token_weight_pairs, device=None):
constant = 6.0
if comfy.model_management.should_use_bf16(device):
constant /= 2.0
def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
class LTXAVTEModel_(LTXAVTEModel): class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: if llama_quantization_metadata is not None:
model_options = model_options.copy() model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8 model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None: if dtype_llama is not None:
dtype = dtype_llama dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)

View File

@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None: if t5_quantization_metadata is not None:
model_options = model_options.copy() model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None: if dtype_t5 is not None:
dtype = dtype_t5 dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options) super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_ return PixArtTEModel_

View File

@ -1113,6 +1113,18 @@ class DynamicSlot(ComfyTypeI):
out_dict[input_type][finalized_id] = value out_dict[input_type][finalized_id] = value
out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1])
@comfytype(io_type="IMAGECOMPARE")
class ImageCompare(ComfyTypeI):
Type = dict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True):
super().__init__(id, display_name, optional, tooltip, None, None, socketless)
def as_dict(self):
return super().as_dict()
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -1958,4 +1970,5 @@ __all__ = [
"add_to_dict_v1", "add_to_dict_v1",
"add_to_dict_v3", "add_to_dict_v3",
"V3Data", "V3Data",
"ImageCompare",
] ]

View File

@ -0,0 +1,41 @@
from pydantic import BaseModel, Field
class SubjectReference(BaseModel):
id: str = Field(...)
images: list[str] = Field(...)
class TaskCreationRequest(BaseModel):
model: str = Field(...)
prompt: str = Field(..., max_length=2000)
duration: int = Field(...)
seed: int = Field(..., ge=0, le=2147483647)
aspect_ratio: str | None = Field(None)
resolution: str | None = Field(None)
movement_amplitude: str | None = Field(None)
images: list[str] | None = Field(None, description="Base64 encoded string or image URL")
subjects: list[SubjectReference] | None = Field(None)
bgm: bool | None = Field(None)
audio: bool | None = Field(None)
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: str = Field(...)
created_at: str = Field(...)
code: int | None = Field(None, description="Error code")
class TaskResult(BaseModel):
id: str = Field(..., description="Creation id")
url: str = Field(..., description="The URL of the generated results, valid for one hour")
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
class TaskStatusResponse(BaseModel):
state: str = Field(...)
err_code: str | None = Field(None)
progress: float | None = Field(None)
credits: int | None = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")

View File

@ -567,7 +567,7 @@ async def execute_lipsync(
# Upload the audio file to Comfy API and get download URL # Upload the audio file to Comfy API and get download URL
if audio: if audio:
audio_url = await upload_audio_to_comfyapi( audio_url = await upload_audio_to_comfyapi(
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg"
) )
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else: else:

View File

@ -2,7 +2,6 @@ import builtins
from io import BytesIO from io import BytesIO
import aiohttp import aiohttp
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input
@ -138,7 +137,7 @@ class TopazImageEnhance(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
model: str, model: str,
image: torch.Tensor, image: Input.Image,
prompt: str = "", prompt: str = "",
subject_detection: str = "All", subject_detection: str = "All",
face_enhancement: bool = True, face_enhancement: bool = True,
@ -153,7 +152,9 @@ class TopazImageEnhance(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if get_number_of_images(image) != 1: if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.") raise ValueError("Only one input image is supported.")
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") download_url = await upload_images_to_comfyapi(
cls, image, max_images=1, mime_type="image/png", total_pixels=4096*4096
)
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),

View File

@ -1,12 +1,13 @@
import logging
from enum import Enum
from typing import Literal, Optional, TypeVar
import torch
from pydantic import BaseModel, Field
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.vidu import (
SubjectReference,
TaskCreationRequest,
TaskCreationResponse,
TaskResult,
TaskStatusResponse,
)
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
download_url_to_video_output, download_url_to_video_output,
@ -17,6 +18,7 @@ from comfy_api_nodes.util import (
validate_image_aspect_ratio, validate_image_aspect_ratio,
validate_image_dimensions, validate_image_dimensions,
validate_images_aspect_ratio_closeness, validate_images_aspect_ratio_closeness,
validate_string,
) )
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
@ -25,98 +27,33 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video"
VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video"
VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
R = TypeVar("R")
class VideoModelName(str, Enum):
vidu_q1 = "viduq1"
class AspectRatio(str, Enum):
r_16_9 = "16:9"
r_9_16 = "9:16"
r_1_1 = "1:1"
class Resolution(str, Enum):
r_1080p = "1080p"
class MovementAmplitude(str, Enum):
auto = "auto"
small = "small"
medium = "medium"
large = "large"
class TaskCreationRequest(BaseModel):
model: VideoModelName = VideoModelName.vidu_q1
prompt: Optional[str] = Field(None, max_length=1500)
duration: Optional[Literal[5]] = 5
seed: Optional[int] = Field(0, ge=0, le=2147483647)
aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9
resolution: Optional[Resolution] = Resolution.r_1080p
movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
class TaskCreationResponse(BaseModel):
task_id: str = Field(...)
state: str = Field(...)
created_at: str = Field(...)
code: Optional[int] = Field(None, description="Error code")
class TaskResult(BaseModel):
id: str = Field(..., description="Creation id")
url: str = Field(..., description="The URL of the generated results, valid for one hour")
cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour")
class TaskStatusResponse(BaseModel):
state: str = Field(...)
err_code: Optional[str] = Field(None)
creations: list[TaskResult] = Field(..., description="Generated results")
def get_video_url_from_response(response) -> Optional[str]:
if response.creations:
return response.creations[0].url
return None
def get_video_from_response(response) -> TaskResult:
if not response.creations:
error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
logging.info(error_msg)
raise RuntimeError(error_msg)
logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url)
return response.creations[0]
async def execute_task( async def execute_task(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
vidu_endpoint: str, vidu_endpoint: str,
payload: TaskCreationRequest, payload: TaskCreationRequest,
estimated_duration: int, ) -> list[TaskResult]:
) -> R: task_creation_response = await sync_op(
response = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
response_model=TaskCreationResponse, response_model=TaskCreationResponse,
data=payload, data=payload,
) )
if response.state == "failed": if task_creation_response.state == "failed":
error_msg = f"Vidu request failed. Code: {response.code}" raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}")
logging.error(error_msg) response = await poll_op(
raise RuntimeError(error_msg)
return await poll_op(
cls, cls,
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: r.state, status_extractor=lambda r: r.state,
estimated_duration=estimated_duration, progress_extractor=lambda r: r.progress,
max_poll_attempts=320,
) )
if not response.creations:
raise RuntimeError(
f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}"
)
return response.creations
class ViduTextToVideoNode(IO.ComfyNode): class ViduTextToVideoNode(IO.ComfyNode):
@ -127,14 +64,9 @@ class ViduTextToVideoNode(IO.ComfyNode):
node_id="ViduTextToVideoNode", node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation", display_name="Vidu Text To Video Generation",
category="api node/video/Vidu", category="api node/video/Vidu",
description="Generate video from text prompt", description="Generate video from a text prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
"model",
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
IO.String.Input( IO.String.Input(
"prompt", "prompt",
multiline=True, multiline=True,
@ -163,22 +95,19 @@ class ViduTextToVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=AspectRatio, options=["16:9", "9:16", "1:1"],
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video", tooltip="The aspect ratio of the output video",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
"resolution", "resolution",
options=Resolution, options=["1080p"],
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration", tooltip="Supported values may vary by model & duration",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
"movement_amplitude", "movement_amplitude",
options=MovementAmplitude, options=["auto", "small", "medium", "large"],
default=MovementAmplitude.auto,
tooltip="The movement amplitude of objects in the frame", tooltip="The movement amplitude of objects in the frame",
optional=True, optional=True,
), ),
@ -208,7 +137,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
if not prompt: if not prompt:
raise ValueError("The prompt field is required and cannot be empty.") raise ValueError("The prompt field is required and cannot be empty.")
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model=model,
prompt=prompt, prompt=prompt,
duration=duration, duration=duration,
seed=seed, seed=seed,
@ -216,8 +145,8 @@ class ViduTextToVideoNode(IO.ComfyNode):
resolution=resolution, resolution=resolution,
movement_amplitude=movement_amplitude, movement_amplitude=movement_amplitude,
) )
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduImageToVideoNode(IO.ComfyNode): class ViduImageToVideoNode(IO.ComfyNode):
@ -230,12 +159,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
category="api node/video/Vidu", category="api node/video/Vidu",
description="Generate video from image and optional prompt", description="Generate video from image and optional prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
"model",
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
IO.Image.Input( IO.Image.Input(
"image", "image",
tooltip="An image to be used as the start frame of the generated video", tooltip="An image to be used as the start frame of the generated video",
@ -270,15 +194,13 @@ class ViduImageToVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"resolution", "resolution",
options=Resolution, options=["1080p"],
default=Resolution.r_1080p,
tooltip="Supported values may vary by model & duration", tooltip="Supported values may vary by model & duration",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
"movement_amplitude", "movement_amplitude",
options=MovementAmplitude, options=["auto", "small", "medium", "large"],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame", tooltip="The movement amplitude of objects in the frame",
optional=True, optional=True,
), ),
@ -298,7 +220,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
model: str, model: str,
image: torch.Tensor, image: Input.Image,
prompt: str, prompt: str,
duration: int, duration: int,
seed: int, seed: int,
@ -309,7 +231,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
raise ValueError("Only one input image is allowed.") raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model=model,
prompt=prompt, prompt=prompt,
duration=duration, duration=duration,
seed=seed, seed=seed,
@ -322,8 +244,8 @@ class ViduImageToVideoNode(IO.ComfyNode):
max_images=1, max_images=1,
mime_type="image/png", mime_type="image/png",
) )
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduReferenceVideoNode(IO.ComfyNode): class ViduReferenceVideoNode(IO.ComfyNode):
@ -334,14 +256,9 @@ class ViduReferenceVideoNode(IO.ComfyNode):
node_id="ViduReferenceVideoNode", node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation", display_name="Vidu Reference To Video Generation",
category="api node/video/Vidu", category="api node/video/Vidu",
description="Generate video from multiple images and prompt", description="Generate video from multiple images and a prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
"model",
options=VideoModelName,
default=VideoModelName.vidu_q1,
tooltip="Model name",
),
IO.Image.Input( IO.Image.Input(
"images", "images",
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
@ -374,22 +291,19 @@ class ViduReferenceVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=AspectRatio, options=["16:9", "9:16", "1:1"],
default=AspectRatio.r_16_9,
tooltip="The aspect ratio of the output video", tooltip="The aspect ratio of the output video",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
"resolution", "resolution",
options=[model.value for model in Resolution], options=["1080p"],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration", tooltip="Supported values may vary by model & duration",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
"movement_amplitude", "movement_amplitude",
options=[model.value for model in MovementAmplitude], options=["auto", "small", "medium", "large"],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame", tooltip="The movement amplitude of objects in the frame",
optional=True, optional=True,
), ),
@ -409,7 +323,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
model: str, model: str,
images: torch.Tensor, images: Input.Image,
prompt: str, prompt: str,
duration: int, duration: int,
seed: int, seed: int,
@ -426,7 +340,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128) validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model=model,
prompt=prompt, prompt=prompt,
duration=duration, duration=duration,
seed=seed, seed=seed,
@ -440,8 +354,8 @@ class ViduReferenceVideoNode(IO.ComfyNode):
max_images=7, max_images=7,
mime_type="image/png", mime_type="image/png",
) )
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduStartEndToVideoNode(IO.ComfyNode): class ViduStartEndToVideoNode(IO.ComfyNode):
@ -454,12 +368,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
category="api node/video/Vidu", category="api node/video/Vidu",
description="Generate a video from start and end frames and a prompt", description="Generate a video from start and end frames and a prompt",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
"model",
options=[model.value for model in VideoModelName],
default=VideoModelName.vidu_q1.value,
tooltip="Model name",
),
IO.Image.Input( IO.Image.Input(
"first_frame", "first_frame",
tooltip="Start frame", tooltip="Start frame",
@ -497,15 +406,13 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"resolution", "resolution",
options=[model.value for model in Resolution], options=["1080p"],
default=Resolution.r_1080p.value,
tooltip="Supported values may vary by model & duration", tooltip="Supported values may vary by model & duration",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
"movement_amplitude", "movement_amplitude",
options=[model.value for model in MovementAmplitude], options=["auto", "small", "medium", "large"],
default=MovementAmplitude.auto.value,
tooltip="The movement amplitude of objects in the frame", tooltip="The movement amplitude of objects in the frame",
optional=True, optional=True,
), ),
@ -525,8 +432,8 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
async def execute( async def execute(
cls, cls,
model: str, model: str,
first_frame: torch.Tensor, first_frame: Input.Image,
end_frame: torch.Tensor, end_frame: Input.Image,
prompt: str, prompt: str,
duration: int, duration: int,
seed: int, seed: int,
@ -535,7 +442,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model=model,
prompt=prompt, prompt=prompt,
duration=duration, duration=duration,
seed=seed, seed=seed,
@ -546,8 +453,391 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame) for frame in (first_frame, end_frame)
] ]
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2TextToVideoNode",
display_name="Vidu2 Text-to-Video Generation",
category="api node/video/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A textual description for video generation, with a maximum length of 2000 characters.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=10,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Boolean.Input(
"background_music",
default=False,
tooltip="Whether to add background music to the generated video.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
background_music: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2000)
results = await execute_task(
cls,
VIDU_TEXT_TO_VIDEO,
TaskCreationRequest(
model=model,
prompt=prompt,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
bgm=background_music,
),
)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2ImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2ImageToVideoNode",
display_name="Vidu2 Image-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
IO.Image.Input(
"image",
tooltip="An image to be used as the start frame of the generated video.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="An optional text prompt for video generation (max 2000 characters).",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=10,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: Input.Image,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_string(prompt, max_length=2000)
results = await execute_task(
cls,
VIDU_IMAGE_TO_VIDEO,
TaskCreationRequest(
model=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
images=await upload_images_to_comfyapi(
cls,
image,
max_images=1,
mime_type="image/png",
),
),
)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2ReferenceVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2ReferenceVideoNode",
display_name="Vidu2 Reference-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from multiple reference images and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
IO.Autogrow.Input(
"subjects",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_images"),
names=["subject1", "subject2", "subject3"],
min=1,
),
tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
"Reference them in prompts via @subject{subject_id}.",
),
IO.String.Input(
"prompt",
multiline=True,
tooltip="When enabled, the video will include generated speech and background music "
"based on the prompt.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled video will contain generated speech and background music based on the prompt.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=10,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
IO.Combo.Input("resolution", options=["720p"]),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
subjects: IO.Autogrow.Type,
prompt: str,
audio: bool,
duration: int,
seed: int,
aspect_ratio: str,
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2000)
total_images = 0
for i in subjects:
if get_number_of_images(subjects[i]) > 3:
raise ValueError("Maximum number of images per subject is 3.")
for im in subjects[i]:
total_images += 1
validate_image_aspect_ratio(im, (1, 4), (4, 1))
validate_image_dimensions(im, min_width=128, min_height=128)
if total_images > 7:
raise ValueError("Too many reference images; the maximum allowed is 7.")
subjects_param: list[SubjectReference] = []
for i in subjects:
subjects_param.append(
SubjectReference(
id=i,
images=await upload_images_to_comfyapi(
cls,
subjects[i],
max_images=3,
mime_type="image/png",
wait_label=f"Uploading reference images for {i}",
),
),
)
payload = TaskCreationRequest(
model=model,
prompt=prompt,
audio=audio,
duration=duration,
seed=seed,
aspect_ratio=aspect_ratio,
resolution=resolution,
movement_amplitude=movement_amplitude,
subjects=subjects_param,
)
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu2StartEndToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu2StartEndToVideoNode",
display_name="Vidu2 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
IO.Image.Input("first_frame"),
IO.Image.Input("end_frame"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt description (max 2000 characters).",
),
IO.Int.Input(
"duration",
default=5,
min=2,
max=8,
step=1,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input("resolution", options=["720p", "1080p"]),
IO.Combo.Input(
"movement_amplitude",
options=["auto", "small", "medium", "large"],
tooltip="The movement amplitude of objects in the frame.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
first_frame: Input.Image,
end_frame: Input.Image,
prompt: str,
duration: int,
seed: int,
resolution: str,
movement_amplitude: str,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2000)
if get_number_of_images(first_frame) > 1:
raise ValueError("Only one input image is allowed for `first_frame`.")
if get_number_of_images(end_frame) > 1:
raise ValueError("Only one input image is allowed for `end_frame`.")
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model=model,
prompt=prompt,
duration=duration,
seed=seed,
resolution=resolution,
movement_amplitude=movement_amplitude,
images=[
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame)
],
)
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduExtension(ComfyExtension): class ViduExtension(ComfyExtension):
@ -558,6 +848,10 @@ class ViduExtension(ComfyExtension):
ViduImageToVideoNode, ViduImageToVideoNode,
ViduReferenceVideoNode, ViduReferenceVideoNode,
ViduStartEndToVideoNode, ViduStartEndToVideoNode,
Vidu2TextToVideoNode,
Vidu2ImageToVideoNode,
Vidu2ReferenceVideoNode,
Vidu2StartEndToVideoNode,
] ]

View File

@ -55,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio( def tensor_to_bytesio(
image: torch.Tensor, image: torch.Tensor,
name: str | None = None, *,
total_pixels: int = 2048 * 2048, total_pixels: int = 2048 * 2048,
mime_type: str = "image/png", mime_type: str = "image/png",
) -> BytesIO: ) -> BytesIO:
@ -75,7 +75,7 @@ def tensor_to_bytesio(
pil_image = tensor_to_pil(image, total_pixels=total_pixels) pil_image = tensor_to_pil(image, total_pixels=total_pixels)
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
return img_binary return img_binary

View File

@ -49,6 +49,7 @@ async def upload_images_to_comfyapi(
mime_type: str | None = None, mime_type: str | None = None,
wait_label: str | None = "Uploading", wait_label: str | None = "Uploading",
show_batch_index: bool = True, show_batch_index: bool = True,
total_pixels: int = 2048 * 2048,
) -> list[str]: ) -> list[str]:
""" """
Uploads images to ComfyUI API and returns download URLs. Uploads images to ComfyUI API and returns download URLs.
@ -63,7 +64,7 @@ async def upload_images_to_comfyapi(
for idx in range(num_to_upload): for idx in range(num_to_upload):
tensor = image[idx] if is_batch else image tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type) img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
effective_label = wait_label effective_label = wait_label
if wait_label and show_batch_index and num_to_upload > 1: if wait_label and show_batch_index and num_to_upload > 1:
@ -81,7 +82,6 @@ async def upload_audio_to_comfyapi(
container_format: str = "mp4", container_format: str = "mp4",
codec_name: str = "aac", codec_name: str = "aac",
mime_type: str = "audio/mp4", mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str: ) -> str:
""" """
Uploads a single audio input to ComfyUI API and returns its download URL. Uploads a single audio input to ComfyUI API and returns its download URL.
@ -91,7 +91,7 @@ async def upload_audio_to_comfyapi(
waveform: torch.Tensor = audio["waveform"] waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type)
async def upload_video_to_comfyapi( async def upload_video_to_comfyapi(

View File

@ -14,8 +14,9 @@ class JobStatus:
IN_PROGRESS = 'in_progress' IN_PROGRESS = 'in_progress'
COMPLETED = 'completed' COMPLETED = 'completed'
FAILED = 'failed' FAILED = 'failed'
CANCELLED = 'cancelled'
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED] ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
# Media types that can be previewed in the frontend # Media types that can be previewed in the frontend
@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
status_info = history_item.get('status', {}) status_info = history_item.get('status', {})
status_str = status_info.get('status_str') if status_info else None status_str = status_info.get('status_str') if status_info else None
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.FAILED
else:
status = JobStatus.COMPLETED
outputs = history_item.get('outputs', {}) outputs = history_item.get('outputs', {})
outputs_count, preview_output = get_outputs_summary(outputs) outputs_count, preview_output = get_outputs_summary(outputs)
@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_error = None execution_error = None
execution_start_time = None execution_start_time = None
execution_end_time = None execution_end_time = None
was_interrupted = False
if status_info: if status_info:
messages = status_info.get('messages', []) messages = status_info.get('messages', [])
for entry in messages: for entry in messages:
@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
execution_end_time = event_data.get('timestamp') execution_end_time = event_data.get('timestamp')
if event_name == 'execution_error': if event_name == 'execution_error':
execution_error = event_data execution_error = event_data
elif event_name == 'execution_interrupted':
was_interrupted = True
if status_str == 'success':
status = JobStatus.COMPLETED
elif status_str == 'error':
status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED
else:
status = JobStatus.COMPLETED
job = prune_dict({ job = prune_dict({
'id': prompt_id, 'id': prompt_id,
@ -268,13 +273,13 @@ def get_all_jobs(
for item in queued: for item in queued:
jobs.append(normalize_queue_item(item, JobStatus.PENDING)) jobs.append(normalize_queue_item(item, JobStatus.PENDING))
include_completed = JobStatus.COMPLETED in status_filter history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED}
include_failed = JobStatus.FAILED in status_filter requested_history_statuses = history_statuses & set(status_filter)
if include_completed or include_failed: if requested_history_statuses:
for prompt_id, history_item in history.items(): for prompt_id, history_item in history.items():
is_failed = history_item.get('status', {}).get('status_str') == 'error' job = normalize_history_item(prompt_id, history_item)
if (is_failed and include_failed) or (not is_failed and include_completed): if job.get('status') in requested_history_statuses:
jobs.append(normalize_history_item(prompt_id, history_item)) jobs.append(job)
if workflow_id: if workflow_id:
jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] jobs = [j for j in jobs if j.get('workflow_id') == workflow_id]

View File

@ -399,6 +399,58 @@ class SplitAudioChannels(IO.ComfyNode):
separate = execute # TODO: remove separate = execute # TODO: remove
class JoinAudioChannels(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="JoinAudioChannels",
display_name="Join Audio Channels",
description="Joins left and right mono audio channels into a stereo audio.",
category="audio",
inputs=[
IO.Audio.Input("audio_left"),
IO.Audio.Input("audio_right"),
],
outputs=[
IO.Audio.Output(display_name="audio"),
],
)
@classmethod
def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
waveform_left = audio_left["waveform"]
sample_rate_left = audio_left["sample_rate"]
waveform_right = audio_right["waveform"]
sample_rate_right = audio_right["sample_rate"]
if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
raise ValueError("AudioJoin: Both input audios must be mono.")
# Handle different sample rates by resampling to the higher rate
waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
waveform_left, sample_rate_left, waveform_right, sample_rate_right
)
# Handle different lengths by trimming to the shorter length
length_left = waveform_left.shape[-1]
length_right = waveform_right.shape[-1]
if length_left != length_right:
min_length = min(length_left, length_right)
if length_left > min_length:
logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
waveform_left = waveform_left[..., :min_length]
if length_right > min_length:
logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
waveform_right = waveform_right[..., :min_length]
# Join the channels into stereo
left_channel = waveform_left[..., 0:1, :]
right_channel = waveform_right[..., 0:1, :]
stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
if sample_rate_1 != sample_rate_2: if sample_rate_1 != sample_rate_2:
@ -616,6 +668,7 @@ class AudioExtension(ComfyExtension):
RecordAudio, RecordAudio,
TrimAudioDuration, TrimAudioDuration,
SplitAudioChannels, SplitAudioChannels,
JoinAudioChannels,
AudioConcat, AudioConcat,
AudioMerge, AudioMerge,
AudioAdjustVolume, AudioAdjustVolume,

View File

@ -0,0 +1,53 @@
import nodes
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
class ImageCompare(IO.ComfyNode):
"""Compares two images with a slider interface."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageCompare",
display_name="Image Compare",
description="Compares two images side by side with a slider.",
category="image",
is_experimental=True,
is_output_node=True,
inputs=[
IO.Image.Input("image_a", optional=True),
IO.Image.Input("image_b", optional=True),
IO.ImageCompare.Input("compare_view"),
],
outputs=[],
)
@classmethod
def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput:
result = {"a_images": [], "b_images": []}
preview_node = nodes.PreviewImage()
if image_a is not None and len(image_a) > 0:
saved = preview_node.save_images(image_a, "comfy.compare.a")
result["a_images"] = saved["ui"]["images"]
if image_b is not None and len(image_b) > 0:
saved = preview_node.save_images(image_b, "comfy.compare.b")
result["b_images"] = saved["ui"]["images"]
return IO.NodeOutput(ui=result)
class ImageCompareExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCompare,
]
async def comfy_entrypoint() -> ImageCompareExtension:
return ImageCompareExtension()

View File

@ -185,6 +185,10 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
io.Combo.Input( io.Combo.Input(
"ckpt_name", "ckpt_name",
options=folder_paths.get_filename_list("checkpoints"), options=folder_paths.get_filename_list("checkpoints"),
),
io.Combo.Input(
"device",
options=["default", "cpu"],
) )
], ],
outputs=[io.Clip.Output()], outputs=[io.Clip.Output()],
@ -197,7 +201,11 @@ class LTXAVTextEncoderLoader(io.ComfyNode):
clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder)
clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) model_options = {}
if device == "cpu":
model_options["load_device"] = model_options["offload_device"] = torch.device("cpu")
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options)
return io.NodeOutput(clip) return io.NodeOutput(clip)

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.8.0" __version__ = "0.9.1"

View File

@ -7,6 +7,7 @@ import folder_paths
import time import time
from comfy.cli_args import args from comfy.cli_args import args
from app.logger import setup_logger from app.logger import setup_logger
from app.assets.scanner import seed_assets
import itertools import itertools
import utils.extra_config import utils.extra_config
import logging import logging
@ -324,6 +325,8 @@ def setup_database():
from app.database.db import init_db, dependencies_available from app.database.db import init_db, dependencies_available
if dependencies_available(): if dependencies_available():
init_db() init_db()
if not args.disable_assets_autoscan:
seed_assets(["models"], enable_logging=True)
except Exception as e: except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")

View File

@ -1 +1 @@
comfyui_manager==4.0.4 comfyui_manager==4.0.5

View File

@ -378,14 +378,15 @@ class VAEEncodeForInpaint:
CATEGORY = "latent/inpaint" CATEGORY = "latent/inpaint"
def encode(self, vae, pixels, mask, grow_mask_by=6): def encode(self, vae, pixels, mask, grow_mask_by=6):
x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio downscale_ratio = vae.spacial_compression_encode()
y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio x = (pixels.shape[1] // downscale_ratio) * downscale_ratio
y = (pixels.shape[2] // downscale_ratio) * downscale_ratio
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear")
pixels = pixels.clone() pixels = pixels.clone()
if pixels.shape[1] != x or pixels.shape[2] != y: if pixels.shape[1] != x or pixels.shape[2] != y:
x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 x_offset = (pixels.shape[1] % downscale_ratio) // 2
y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 y_offset = (pixels.shape[2] % downscale_ratio) // 2
pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:]
mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset]
@ -2369,6 +2370,7 @@ async def init_builtin_extra_nodes():
"nodes_nop.py", "nodes_nop.py",
"nodes_kandinsky5.py", "nodes_kandinsky5.py",
"nodes_wanmove.py", "nodes_wanmove.py",
"nodes_image_compare.py",
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.8.0" version = "0.9.1"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.35.9 comfyui-frontend-package==1.36.14
comfyui-workflow-templates==0.7.67 comfyui-workflow-templates==0.8.4
comfyui-embedded-docs==0.3.1 comfyui-embedded-docs==0.4.0
torch torch
torchsde torchsde
torchvision torchvision
@ -21,7 +21,7 @@ psutil
alembic alembic
SQLAlchemy SQLAlchemy
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.3 comfy-kitchen>=0.2.6
#non essential dependencies: #non essential dependencies:
kornia>=0.7.1 kornia>=0.7.1

View File

@ -33,6 +33,8 @@ import node_helpers
from comfyui_version import __version__ from comfyui_version import __version__
from app.frontend_management import FrontendManager, parse_version from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal from comfy_api.internal import _ComfyNodeInternal
from app.assets.scanner import seed_assets
from app.assets.api.routes import register_assets_system
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -184,7 +186,7 @@ def create_block_external_middleware():
else: else:
response = await handler(request) response = await handler(request)
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self' data:; frame-src 'self'; object-src 'self';"
return response return response
return block_external_middleware return block_external_middleware
@ -235,6 +237,7 @@ class PromptServer():
else args.front_end_root else args.front_end_root
) )
logging.info(f"[Prompt Server] web root: {self.web_root}") logging.info(f"[Prompt Server] web root: {self.web_root}")
register_assets_system(self.app, self.user_manager)
routes = web.RouteTableDef() routes = web.RouteTableDef()
self.routes = routes self.routes = routes
self.last_node_id = None self.last_node_id = None
@ -683,6 +686,7 @@ class PromptServer():
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
seed_assets(["models"])
with folder_paths.cache_helper: with folder_paths.cache_helper:
out = {} out = {}
for x in nodes.NODE_CLASS_MAPPINGS: for x in nodes.NODE_CLASS_MAPPINGS:

View File

@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
state_dict2 = model.state_dict() state_dict2 = model.state_dict()
# Verify layer1.weight is a QuantizedTensor with scale preserved # Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0) self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout") self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
# Verify non-quantized layers are standard tensors # Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)

View File

@ -19,6 +19,7 @@ class TestJobStatus:
assert JobStatus.IN_PROGRESS == 'in_progress' assert JobStatus.IN_PROGRESS == 'in_progress'
assert JobStatus.COMPLETED == 'completed' assert JobStatus.COMPLETED == 'completed'
assert JobStatus.FAILED == 'failed' assert JobStatus.FAILED == 'failed'
assert JobStatus.CANCELLED == 'cancelled'
def test_all_contains_all_statuses(self): def test_all_contains_all_statuses(self):
"""ALL should contain all status values.""" """ALL should contain all status values."""
@ -26,7 +27,8 @@ class TestJobStatus:
assert JobStatus.IN_PROGRESS in JobStatus.ALL assert JobStatus.IN_PROGRESS in JobStatus.ALL
assert JobStatus.COMPLETED in JobStatus.ALL assert JobStatus.COMPLETED in JobStatus.ALL
assert JobStatus.FAILED in JobStatus.ALL assert JobStatus.FAILED in JobStatus.ALL
assert len(JobStatus.ALL) == 4 assert JobStatus.CANCELLED in JobStatus.ALL
assert len(JobStatus.ALL) == 5
class TestIsPreviewable: class TestIsPreviewable:
@ -336,6 +338,40 @@ class TestNormalizeHistoryItem:
assert job['execution_error']['node_type'] == 'KSampler' assert job['execution_error']['node_type'] == 'KSampler'
assert job['execution_error']['exception_message'] == 'CUDA out of memory' assert job['execution_error']['exception_message'] == 'CUDA out of memory'
def test_cancelled_job(self):
"""Cancelled/interrupted history item should have cancelled status."""
history_item = {
'prompt': (
5,
'prompt-cancelled',
{'nodes': {}},
{'create_time': 1234567890000},
['node1'],
),
'status': {
'status_str': 'error',
'completed': False,
'messages': [
('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}),
('execution_interrupted', {
'prompt_id': 'prompt-cancelled',
'node_id': '5',
'node_type': 'KSampler',
'executed': ['1', '2', '3'],
'timestamp': 1234567891000,
})
]
},
'outputs': {},
}
job = normalize_history_item('prompt-cancelled', history_item)
assert job['status'] == 'cancelled'
assert job['execution_start_time'] == 1234567890500
assert job['execution_end_time'] == 1234567891000
# Cancelled jobs should not have execution_error set
assert 'execution_error' not in job
def test_include_outputs(self): def test_include_outputs(self):
"""When include_outputs=True, should include full output data.""" """When include_outputs=True, should include full output data."""
history_item = { history_item = {