Compare commits

...

37 Commits

Author SHA1 Message Date
Benjamin Lu
9304e47351
Update workflows for new release process (#11064)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.9) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
* Update release workflows for branch process

* Adjust branch order in workflow triggers

* Revert changes in test workflows
2025-12-15 23:24:18 -08:00
comfyanonymous
bc606d7d64
Add a way to set the default ref method in the qwen image code. (#11349) 2025-12-16 01:26:55 -05:00
comfyanonymous
645ee1881e
Inpainting for z image fun control. Use the ZImageFunControlnet node. (#11346)
image -> control image ex: pose
inpaint_image -> image for inpainting
mask -> inpaint mask
2025-12-15 23:38:12 -05:00
Christian Byrne
3d082c3206
bump comfyui-frontend-package to 1.34.9 (patch) (#11342) 2025-12-15 23:35:37 -05:00
comfyanonymous
683569de55
Only enable fp16 on ZImage on newer pytorch. (#11344) 2025-12-15 22:33:27 -05:00
Haoming
ea2c117bc3
[BlockInfo] Wan (#10845)
* block info

* animate

* tensor

* device

* revert
2025-12-15 17:59:16 -08:00
Haoming
fc4af86068
[BlockInfo] Lumina (#11227)
* block info

* device

* Make tensor int again

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2025-12-15 17:57:28 -08:00
comfyanonymous
41bcf0619d
Add code to detect if a z image fun controlnet is broken or not. (#11341) 2025-12-15 20:51:06 -05:00
seed93
d02d0e5744
[add] tripo3.0 (#10663)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.9) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* [add] tripo3.0

* [tripo] change paramter order

* change order

---------

Co-authored-by: liangd <liangding@vastai3d.com>
2025-12-15 17:38:46 -08:00
comfyanonymous
70541d4e77
Support the new qwen edit 2511 reference method. (#11340)
index_timestep_zero can be selected in the
FluxKontextMultiReferenceLatentMethod now with the display name set to the
more generic "Edit Model Reference Method" node.
2025-12-15 19:20:34 -05:00
drozbay
77b2f7c228
Add context windows callback for custom cond handling (#11208)
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2025-12-15 16:06:32 -08:00
Alexander Piskun
43e0d4e3cc
comfy_api: remove usage of "Type","List" and "Dict" types (#11238) 2025-12-15 16:01:10 -08:00
Dr.Lt.Data
dbd330454a
feat(preview): add per-queue live preview method override (#11261)
- Add set_preview_method() to override live preview method per queue item
- Read extra_data.preview_method from /prompt request
- Support values: taesd, latent2rgb, none, auto, default
- "default" or unset uses server's CLI --preview-method setting
- Add 44 tests (37 unit + 7 E2E)
2025-12-15 15:57:39 -08:00
Alexander Piskun
33c7f1179d
drop Pika API nodes (#11306) 2025-12-15 15:32:29 -08:00
Alexander Piskun
af91eb6c99
api-nodes: drop Kling v1 model (#11307) 2025-12-15 15:30:24 -08:00
comfyanonymous
5cb1e0c9a0
Disable guards on transformer_options when torch.compile (#11317) 2025-12-15 16:49:29 -05:00
ComfyUI Wiki
51347f9fb8
chore: update workflow templates to v0.7.59 (#11337) 2025-12-15 16:28:55 -05:00
Dr.Lt.Data
a5e85017d8
bump manager requirments to the 4.0.3b5 (#11324) 2025-12-15 14:24:01 -05:00
comfyanonymous
5ac3b26a7d
Update warning for old pytorch version. (#11319)
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Versions below 2.4 are no longer supported. We will not break support on purpose but will not fix it if we do.
2025-12-14 04:02:50 -05:00
chaObserv
6592bffc60
seeds_2: add phi_2 variant and sampler node (#11309)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
* Add phi_2 solver type to seeds_2

* Add sampler node of seeds_2
2025-12-14 00:03:29 -05:00
comfyanonymous
971cefe7d4
Fix pytorch warnings. (#11314)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-13 18:45:23 -05:00
comfyanonymous
da2bfb5b0a
Basic implementation of z image fun control union 2.0 (#11304)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
The inpaint part is currently missing and will be implemented later.

I think they messed up this model pretty bad. They added some
control_noise_refiner blocks but don't actually use them. There is a typo
in their code so instead of doing control_noise_refiner -> control_layers
it runs the whole control_layers twice.

Unfortunately they trained with this typo so the model works but is kind
of slow and would probably perform a lot better if they corrected their
code and trained it again.
2025-12-13 01:39:11 -05:00
comfyanonymous
c5a47a1692
Fix bias dtype issue in mixed ops. (#11293)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-12 11:49:35 -05:00
Alexander Piskun
908fd7d749
feat(api-nodes): new TextToVideoWithAudio and ImageToVideoWithAudio nodes (#11267)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-12 00:18:31 -08:00
comfyanonymous
5495589db3
Respect the dtype the op was initialized in for non quant mixed op. (#11282)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-11 23:32:27 -05:00
Jukka Seppänen
982876d59a
WanMove support (#11247) 2025-12-11 22:29:34 -05:00
comfyanonymous
338d9ae3bb
Make portable updater work with repos in unmerged state. (#11281) 2025-12-11 18:56:33 -05:00
comfyanonymous
eeb020b9b7
Better chroma radiance and other models vram estimation. (#11278)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
2025-12-11 17:33:09 -05:00
comfyanonymous
ae65433a60
This only works on radiance. (#11277) 2025-12-11 17:15:00 -05:00
comfyanonymous
fdebe18296
Fix regular chroma radiance (#11276) 2025-12-11 17:09:35 -05:00
comfyanonymous
f8321eb57b
Adjust memory usage factor. (#11257)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-11 01:30:31 -05:00
Alexander Piskun
93948e3fc5
feat(api-nodes): enable Kling Omni O1 node (#11229) 2025-12-10 22:11:12 -08:00
Farshore
e711aaf1a7
Lower VAE loading requirements:Create a new branch for GPU memory calculations in qwen-image vae (#11199) 2025-12-10 22:02:26 -05:00
Johnpaul Chiwetelu
57ddb7fd13
Fix: filter hidden files from /internal/files endpoint (#11191) 2025-12-10 21:49:49 -05:00
comfyanonymous
17c92a9f28
Tweak Z Image memory estimation. (#11254) 2025-12-10 19:59:48 -05:00
Alexander Piskun
36357bbcc3
process the NodeV1 dict results correctly (#11237)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-10 11:55:09 -08:00
Benjamin Lu
f668c2e3c9
bump comfyui-frontend-package to 1.34.8 (#11220)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
2025-12-09 22:27:07 -05:00
48 changed files with 2003 additions and 860 deletions

View File

@ -53,6 +53,16 @@ try:
repo.stash(ident) repo.stash(ident)
except KeyError: except KeyError:
print("nothing to stash") # noqa: T201 print("nothing to stash") # noqa: T201
except:
print("Could not stash, cleaning index and trying again.") # noqa: T201
repo.state_cleanup()
repo.index.read_tree(repo.head.peel().tree)
repo.index.write()
try:
repo.stash(ident)
except KeyError:
print("nothing to stash.") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201 print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try: try:

View File

@ -5,6 +5,7 @@ on:
push: push:
branches: branches:
- master - master
- release/**
paths-ignore: paths-ignore:
- 'app/**' - 'app/**'
- 'input/**' - 'input/**'

View File

@ -2,9 +2,9 @@ name: Execution Tests
on: on:
push: push:
branches: [ main, master ] branches: [ main, master, release/** ]
pull_request: pull_request:
branches: [ main, master ] branches: [ main, master, release/** ]
jobs: jobs:
test: test:

View File

@ -2,9 +2,9 @@ name: Test server launches without errors
on: on:
push: push:
branches: [ main, master ] branches: [ main, master, release/** ]
pull_request: pull_request:
branches: [ main, master ] branches: [ main, master, release/** ]
jobs: jobs:
test: test:

View File

@ -2,9 +2,9 @@ name: Unit Tests
on: on:
push: push:
branches: [ main, master ] branches: [ main, master, release/** ]
pull_request: pull_request:
branches: [ main, master ] branches: [ main, master, release/** ]
jobs: jobs:
test: test:

View File

@ -6,6 +6,7 @@ on:
- "pyproject.toml" - "pyproject.toml"
branches: branches:
- master - master
- release/**
jobs: jobs:
update-version: update-version:

View File

@ -58,8 +58,13 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400) return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type) directory = get_directory_by_type(directory_type)
def is_visible_file(entry: os.DirEntry) -> bool:
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
return entry.is_file() and not entry.name.startswith('.')
sorted_files = sorted( sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()), (entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime key=lambda entry: -entry.stat().st_mtime
) )
return web.json_response([entry.name for entry in sorted_files], status=200) return web.json_response([entry.name for entry in sorted_files], status=200)

View File

@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
Latent2RGB = "latent2rgb" Latent2RGB = "latent2rgb"
TAESD = "taesd" TAESD = "taesd"
@classmethod
def from_string(cls, value: str):
for member in cls:
if member.value == value:
return member
return None
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction) parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

View File

@ -87,6 +87,7 @@ class IndexListCallbacks:
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results" COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start" EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup" EXECUTE_CLEANUP = "execute_cleanup"
RESIZE_COND_ITEM = "resize_cond_item"
def init_callbacks(self): def init_callbacks(self):
return {} return {}
@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item = cond_item.copy() new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items(): for cond_key, cond_value in new_cond_item.items():
# Allow callbacks to handle custom conditioning items
handled = False
for callback in comfy.patcher_extension.get_all_callbacks(
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
):
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
if result is not None:
new_cond_item[cond_key] = result
handled = True
break
if handled:
continue
if isinstance(cond_value, torch.Tensor): if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):

View File

@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad() @torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5): def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2. """SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023) arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
""" """
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None) seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args) denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2 # Step 2
denoised_d = torch.lerp(denoised, denoised_2, fac) if solver_type == "phi_1":
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
if inject_noise: if inject_noise:
segment_factor = (r - 1) * h * eta segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp() sde_noise = sde_noise * segment_factor.exp()

View File

@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
ffn_dim_multiplier: float = (8.0 / 3.0), ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
qk_norm: bool = True, qk_norm: bool = True,
n_control_layers=6,
control_in_dim=16,
additional_in_dim=0,
broken=False,
refiner_control=False,
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
super().__init__() super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype} operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.additional_in_dim = 0 self.broken = broken
self.control_in_dim = 16 self.additional_in_dim = additional_in_dim
self.control_in_dim = control_in_dim
n_refiner_layers = 2 n_refiner_layers = 2
self.n_control_layers = 6 self.n_control_layers = n_control_layers
self.control_layers = nn.ModuleList( self.control_layers = nn.ModuleList(
[ [
ZImageControlTransformerBlock( ZImageControlTransformerBlock(
@ -74,28 +80,49 @@ class ZImage_Control(torch.nn.Module):
all_x_embedder = {} all_x_embedder = {}
patch_size = 2 patch_size = 2
f_patch_size = 1 f_patch_size = 1
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.refiner_control = refiner_control
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
self.control_noise_refiner = nn.ModuleList( if self.refiner_control:
[ self.control_noise_refiner = nn.ModuleList(
JointTransformerBlock( [
layer_id, ZImageControlTransformerBlock(
dim, layer_id,
n_heads, dim,
n_kv_heads, n_heads,
multiple_of, n_kv_heads,
ffn_dim_multiplier, multiple_of,
norm_eps, ffn_dim_multiplier,
qk_norm, norm_eps,
modulation=True, qk_norm,
z_image_modulation=True, block_id=layer_id,
operation_settings=operation_settings, operation_settings=operation_settings,
) )
for layer_id in range(n_refiner_layers) for layer_id in range(n_refiner_layers)
] ]
) )
else:
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
modulation=True,
z_image_modulation=True,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
patch_size = 2 patch_size = 2
@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None x_attn_mask = None
for layer in self.control_noise_refiner: if not self.refiner_control:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
return control_context return control_context
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
if self.refiner_control:
if self.broken:
if layer_id == 0:
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if layer_id > 0:
out = None
for i in range(1, len(self.control_layers)):
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if out is None:
out = o
return (out, control_context)
else:
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
else:
return (None, control_context)
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

View File

@ -536,6 +536,7 @@ class NextDiT(nn.Module):
bsz = len(x) bsz = len(x)
pH = pW = self.patch_size pH = pW = self.patch_size
device = x[0].device device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None: if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@ -572,13 +573,21 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
patches = transformer_options.get("patches", {})
# refine context # refine context
for layer in self.context_refiner: for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None padded_img_mask = None
for layer in self.noise_refiner: x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1) padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None mask = None
@ -622,14 +631,18 @@ class NextDiT(nn.Module):
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor) x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device) freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers): for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches: if "double_block" in patches:
for p in patches["double_block"]: for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out: if "img" in out:
img[:, cap_size[0]:] = out["img"] img[:, cap_size[0]:] = out["img"]
if "txt" in out: if "txt" in out:

View File

@ -218,9 +218,24 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations, operations=operations,
) )
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def _apply_gate(self, x, y, gate, timestep_zero_index=None):
if timestep_zero_index is not None:
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
else:
return torch.addcmul(y, gate, x)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1) shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1) if timestep_zero_index is not None:
actual_batch = shift.size(0) // 2
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward( def forward(
self, self,
@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor, encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor, temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_zero_index=None,
transformer_options={}, transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb) img_mod_params = self.img_mod(temb)
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
txt_mod_params = self.txt_mod(temb) txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
del img_mod1 del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1 del txt_mod1
@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module):
del img_modulated del img_modulated
del txt_modulated del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output del img_attn_output
del txt_attn_output del txt_attn_output
del img_gate1 del img_gate1
del txt_gate1 del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@ -302,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module):
pooled_projection_dim: int = 768, pooled_projection_dim: int = 768,
guidance_embeds: bool = False, guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None, image_model=None,
final_layer=True, final_layer=True,
dtype=None, dtype=None,
@ -314,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels or in_channels self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
@ -391,11 +413,14 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x) hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1] num_embeds = hidden_states.shape[1]
timestep_zero_index = None
if ref_latents is not None: if ref_latents is not None:
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index" ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents: for ref in ref_latents:
if index_ref_method: if index_ref_method:
index += 1 index += 1
@ -415,6 +440,10 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1) hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@ -446,7 +475,7 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"]) out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"] hidden_states = out["img"]
@ -458,6 +487,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask, encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb, temb=temb,
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
timestep_zero_index=timestep_zero_index,
transformer_options=transformer_options, transformer_options=transformer_options,
) )
@ -474,6 +504,9 @@ class QwenImageTransformer2DModel(nn.Module):
if add is not None: if add is not None:
hidden_states[:, :add.shape[1]] += add hidden_states[:, :add.shape[1]] += add
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states) hidden_states = self.proj_out(hidden_states)

View File

@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"]) out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
return out return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap}) out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"] x = out["img"]
else: else:
x = block(x, e=e0, freqs=freqs, context=context) x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
if audio_emb is not None: if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len) x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head # head
@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}

View File

@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks): for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}

View File

@ -259,8 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_tile_size"] = 512 dit_config["nerf_tile_size"] = 512
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32 dit_config["nerf_embedder_dtype"] = torch.float32
if "__x0__" in state_dict_keys: # x0 pred if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
else: else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

View File

@ -454,6 +454,9 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch): def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input") self.set_model_patch(patch, "post_input")
def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x rope_options["scale_x"] = scale_x

View File

@ -497,15 +497,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None: ) -> None:
super().__init__() super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} if dtype is None:
# self.factory_kwargs = {"device": device, "dtype": dtype} dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
if bias: self._has_bias = bias
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self.tensor_class = None self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@ -530,7 +529,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes()) layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None: if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
else: else:
self.quant_format = layer_conf.get("format", None) self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm: if not self._full_precision_mm:
@ -560,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
requires_grad=False requires_grad=False
) )
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]: for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}" param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None) _v = state_dict.pop(param_key, None)
@ -581,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
quant_conf = {"format": self.quant_format} quant_conf = {"format": self.quant_format}
if self._full_precision_mm: if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd return sd
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):

View File

@ -549,8 +549,10 @@ class VAE:
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1 # Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd: elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:

View File

@ -28,6 +28,7 @@ from . import supported_models_base
from . import latent_formats from . import latent_formats
from . import diffusers_convert from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE): class SD15(supported_models_base.BASE):
unet_config = { unet_config = {
@ -541,7 +542,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.SD3 latent_format = latent_formats.SD3
memory_usage_factor = 1.2 memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
@ -965,7 +966,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config): def __init__(self, unet_config):
super().__init__(unet_config) super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9 self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device) out = model_base.CosmosPredict2(self, device=device)
@ -1026,9 +1027,15 @@ class ZImage(Lumina2):
"shift": 3.0, "shift": 3.0,
} }
memory_usage_factor = 1.7 memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
@ -1289,7 +1296,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input. # Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.038 memory_usage_factor = 0.044
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device) return model_base.ChromaRadiance(self, device=device)
@ -1332,7 +1339,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6, "shift": 2.6,
} }
memory_usage_factor = 1.65 #TODO memory_usage_factor = 1.95 #TODO
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.Flux latent_format = latent_formats.Flux
@ -1397,7 +1404,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21 latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7 memory_usage_factor = 8.7
supported_inference_dtypes = [torch.bfloat16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1488,7 +1495,7 @@ class Kandinsky5(supported_models_base.BASE):
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.1 #TODO memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32] supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1517,7 +1524,7 @@ class Kandinsky5Image(Kandinsky5):
} }
latent_format = latent_formats.Flux latent_format = latent_formats.Flux
memory_usage_factor = 1.1 #TODO memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device) out = model_base.Kandinsky5Image(self, device=device)

View File

@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.") logging.info("Checkpoint files will always be loaded safely.")
else: else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None: if device is None:
@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None: if quant_metadata is not None:
layers = quant_metadata["layers"] layers = quant_metadata["layers"]
for k, v in layers.items(): for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata return state_dict, metadata

View File

@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility. allowing graceful protocol evolution while maintaining backward compatibility.
""" """
from typing import Any, Dict from typing import Any
from comfy.cli_args import args from comfy.cli_args import args
# Default server capabilities # Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = { SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = {
def get_connection_feature( def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]], sockets_metadata: dict[str, dict[str, Any]],
sid: str, sid: str,
feature_name: str, feature_name: str,
default: Any = False default: Any = False
@ -42,7 +42,7 @@ def get_connection_feature(
def supports_feature( def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]], sockets_metadata: dict[str, dict[str, Any]],
sid: str, sid: str,
feature_name: str feature_name: str
) -> bool: ) -> bool:
@ -60,7 +60,7 @@ def supports_feature(
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]: def get_server_features() -> dict[str, Any]:
""" """
Get the server's feature flags. Get the server's feature flags.

View File

@ -1,4 +1,4 @@
from typing import Type, List, NamedTuple from typing import NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version from packaging import version as packaging_version
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
class ComfyAPIWithVersion(NamedTuple): class ComfyAPIWithVersion(NamedTuple):
version: str version: str
api_class: Type[ComfyAPIBase] api_class: type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version: def parse_version(version_str: str) -> packaging_version.Version:
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
return packaging_version.parse(version_str) return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = [] registered_versions: list[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]): def register_versions(versions: list[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version)) versions.sort(key=lambda x: parse_version(x.version))
global registered_versions global registered_versions
registered_versions = versions registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]: def get_all_versions() -> list[ComfyAPIWithVersion]:
""" """
Returns a list of all registered ComfyAPI versions. Returns a list of all registered ComfyAPI versions.
""" """

View File

@ -8,7 +8,7 @@ import os
import textwrap import textwrap
import threading import threading
from enum import Enum from enum import Enum
from typing import Optional, Type, get_origin, get_args, get_type_hints from typing import Optional, get_origin, get_args, get_type_hints
class TypeTracker: class TypeTracker:
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
return result_container["result"] return result_container["result"]
@classmethod @classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type: def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
""" """
Creates a new class with synchronous versions of all async methods. Creates a new class with synchronous versions of all async methods.
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
@classmethod @classmethod
def _generate_imports( def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker cls, async_class: type, type_tracker: TypeTracker
) -> list[str]: ) -> list[str]:
"""Generate import statements for the stub file.""" """Generate import statements for the stub file."""
imports = [] imports = []
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
return imports return imports
@classmethod @classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]: def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
"""Extract class attributes that are classes themselves.""" """Extract class attributes that are classes themselves."""
class_attributes = [] class_attributes = []
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
def _generate_inner_class_stub( def _generate_inner_class_stub(
cls, cls,
name: str, name: str,
attr: Type, attr: type,
indent: str = " ", indent: str = " ",
type_tracker: Optional[TypeTracker] = None, type_tracker: Optional[TypeTracker] = None,
) -> list[str]: ) -> list[str]:
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
return processed return processed
@classmethod @classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None: def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
""" """
Generate a .pyi stub file for the sync class to help IDEs with type checking. Generate a .pyi stub file for the sync class to help IDEs with type checking.
""" """
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type: def create_sync_class(async_class: type, thread_pool_size=10) -> type:
""" """
Creates a sync version of an async class Creates a sync version of an async class

View File

@ -1,4 +1,4 @@
from typing import Type, TypeVar from typing import TypeVar
class SingletonMetaclass(type): class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass") T = TypeVar("T", bound="SingletonMetaclass")
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
) )
return cls._instances[cls] return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None: def inject_instance(cls: type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, ( assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation" "Cannot inject instance after first instantiation"
) )
SingletonMetaclass._instances[cls] = instance SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T: def get_instance(cls: type[T], *args, **kwargs) -> T:
""" """
Gets the singleton instance of the class, creating it if it doesn't exist. Gets the singleton instance of the class, creating it if it doesn't exist.
""" """

View File

@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.internal.async_to_sync import create_sync_class
@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
if TYPE_CHECKING: if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub] ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest) ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui # create new aliases for io and ui

View File

@ -1,5 +1,5 @@
import torch import torch
from typing import TypedDict, List, Optional from typing import TypedDict, Optional
ImageInput = torch.Tensor ImageInput = torch.Tensor
""" """
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
Optional noise mask tensor in the same format as samples. Optional noise mask tensor in the same format as samples.
""" """
batch_index: Optional[List[int]] batch_index: Optional[list[int]]

View File

@ -774,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO): class AudioEncoderOutput(ComfyTypeIO):
Type = Any Type = Any
@comfytype(io_type="TRACKS")
class Tracks(ComfyTypeIO):
class TrackDict(TypedDict):
track_path: torch.Tensor
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="COMFY_MULTITYPED_V3") @comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType: class MultiType:
Type = Any Type = Any
@ -1815,7 +1822,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"] ui = data["ui"]
if "expand" in data: if "expand" in data:
expand = data["expand"] expand = data["expand"]
return cls(args=args, ui=ui, expand=expand) return cls(*args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any: def __getitem__(self, index) -> Any:
return self.args[index] return self.args[index]
@ -1894,6 +1901,7 @@ __all__ = [
"SEGS", "SEGS",
"AnyType", "AnyType",
"MultiType", "MultiType",
"Tracks",
# Dynamic Types # Dynamic Types
"MatchType", "MatchType",
# "DynamicCombo", # "DynamicCombo",

View File

@ -5,7 +5,6 @@ import os
import random import random
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Type
import av import av
import numpy as np import numpy as np
@ -83,7 +82,7 @@ class ImageSaveHelper:
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8)) return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod @staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo.""" """Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden: if args.disable_metadata or cls is None or not cls.hidden:
return None return None
@ -96,7 +95,7 @@ class ImageSaveHelper:
return metadata return metadata
@staticmethod @staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None: def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG).""" """Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden: if args.disable_metadata or cls is None or not cls.hidden:
return None return None
@ -121,7 +120,7 @@ class ImageSaveHelper:
return metadata return metadata
@staticmethod @staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif: def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images.""" """Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif() exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None: if args.disable_metadata or cls is None or cls.hidden is None:
@ -137,7 +136,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def save_images( def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4, images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]: ) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files.""" """Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -155,7 +154,7 @@ class ImageSaveHelper:
return results return results
@staticmethod @staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages: def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output.""" """Saves a batch of images and returns a UI object for the node output."""
return SavedImages( return SavedImages(
ImageSaveHelper.save_images( ImageSaveHelper.save_images(
@ -169,7 +168,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def save_animated_png( def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult: ) -> SavedResult:
"""Saves a batch of images as a single animated PNG.""" """Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -191,7 +190,7 @@ class ImageSaveHelper:
@staticmethod @staticmethod
def get_save_animated_png_ui( def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages: ) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output.""" """Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png( result = ImageSaveHelper.save_animated_png(
@ -209,7 +208,7 @@ class ImageSaveHelper:
images, images,
filename_prefix: str, filename_prefix: str,
folder_type: FolderType, folder_type: FolderType,
cls: Type[ComfyNode] | None, cls: type[ComfyNode] | None,
fps: float, fps: float,
lossless: bool, lossless: bool,
quality: int, quality: int,
@ -238,7 +237,7 @@ class ImageSaveHelper:
def get_save_animated_webp_ui( def get_save_animated_webp_ui(
images, images,
filename_prefix: str, filename_prefix: str,
cls: Type[ComfyNode] | None, cls: type[ComfyNode] | None,
fps: float, fps: float,
lossless: bool, lossless: bool,
quality: int, quality: int,
@ -267,7 +266,7 @@ class AudioSaveHelper:
audio: dict, audio: dict,
filename_prefix: str, filename_prefix: str,
folder_type: FolderType, folder_type: FolderType,
cls: Type[ComfyNode] | None, cls: type[ComfyNode] | None,
format: str = "flac", format: str = "flac",
quality: str = "128k", quality: str = "128k",
) -> list[SavedResult]: ) -> list[SavedResult]:
@ -372,7 +371,7 @@ class AudioSaveHelper:
@staticmethod @staticmethod
def get_save_audio_ui( def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k", audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios: ) -> SavedAudios:
"""Save and instantly wrap for UI.""" """Save and instantly wrap for UI."""
return SavedAudios( return SavedAudios(
@ -388,7 +387,7 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput): class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs): def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images( self.values = ImageSaveHelper.save_images(
image, image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)), filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput): class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs): def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio( self.values = AudioSaveHelper.save_audio(
audio, audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)), filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),

View File

@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2 from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1 from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: List[Type[ComfyAPIBase]] = [ supported_versions: list[type[ComfyAPIBase]] = [
ComfyAPI_latest, ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2, ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1, ComfyAPIAdapter_v0_0_1,

View File

@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
url: str = Field(..., description="URL for generated image") url: str = Field(..., description="URL for generated image")
class OmniTaskStatusResults(BaseModel): class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None) videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None) images: list[TaskStatusImageResult] | None = Field(None)
class OmniTaskStatusResponseData(BaseModel): class TaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time") created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time") updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.") task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID") task_id: str | None = Field(None, description="Task ID")
task_result: OmniTaskStatusResults | None = Field(None) task_result: TaskStatusResults | None = Field(None)
class OmniTaskStatusResponse(BaseModel): class TaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code") code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message") message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID") request_id: str | None = Field(None, description="Request ID")
data: OmniTaskStatusResponseData | None = Field(None) data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel): class OmniImageParamImage(BaseModel):
@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel):
mode: str = Field("pro") mode: str = Field("pro")
n: int | None = Field(1, le=9) n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")

View File

@ -1,100 +0,0 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum): class TripoModelVersion(str, Enum):
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123' v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919' v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625' v1_4_20240625 = 'v1.4-20240625'
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoTextureQuality(str, Enum): class TripoTextureQuality(str, Enum):
standard = 'standard' standard = 'standard'
detailed = 'detailed' detailed = 'detailed'
@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum): class TripoAnimation(str, Enum):
IDLE = "preset:idle" IDLE = "preset:idle"
WALK = "preset:walk" WALK = "preset:walk"
RUN = "preset:run"
DIVE = "preset:dive"
CLIMB = "preset:climb" CLIMB = "preset:climb"
JUMP = "preset:jump" JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash" SLASH = "preset:slash"
SHOOT = "preset:shoot" SHOOT = "preset:shoot"
HURT = "preset:hurt" HURT = "preset:hurt"
FALL = "preset:fall" FALL = "preset:fall"
TURN = "preset:turn" TURN = "preset:turn"
QUADRUPED_WALK = "preset:quadruped:walk"
HEXAPOD_WALK = "preset:hexapod:walk"
OCTOPOD_WALK = "preset:octopod:walk"
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum): class TripoStylizeStyle(str, Enum):
LEGO = "lego" LEGO = "lego"
@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned" BANNED = "banned"
EXPIRED = "expired" EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel): class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference') type: Optional[str] = Field(None, description='The type of the reference')
file_token: str file_token: str
@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model') model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model') model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model') model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture') texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to') format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model') original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model') quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry') force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to') face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model') flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom') flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture') texture_size: Optional[int] = Field(None, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom') pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
class TripoTaskRequest(RootModel): class TripoTaskRequest(RootModel):
root: Union[ root: Union[

View File

@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName, KlingSingleImageEffectModelName,
) )
from comfy_api_nodes.apis.kling_api import ( from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
OmniImageParamImage, OmniImageParamImage,
OmniParamImage, OmniParamImage,
OmniParamVideo, OmniParamVideo,
@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import (
OmniProImageRequest, OmniProImageRequest,
OmniProReferences2VideoRequest, OmniProReferences2VideoRequest,
OmniProText2VideoRequest, OmniProText2VideoRequest,
OmniTaskStatusResponse, TaskStatusResponse,
TextToVideoWithAudioRequest,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
@ -103,10 +105,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
MODE_TEXT2VIDEO = { MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"), "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"), "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"), "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -127,8 +125,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
MODE_START_END_FRAME = { MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"), "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"), "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"), "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -242,7 +238,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt) return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput: async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code: if response.code:
raise RuntimeError( raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -250,7 +246,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
final_response = await poll_op( final_response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160, max_poll_attempts=160,
) )
@ -483,12 +479,12 @@ async def execute_image2video(
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = await poll_op( final_response = await poll_op(
cls, cls,
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
response_model=KlingImage2VideoResponse, response_model=KlingImage2VideoResponse,
estimated_duration=AVERAGE_DURATION_I2V, estimated_duration=AVERAGE_DURATION_I2V,
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
) )
validate_video_result_response(final_response) validate_video_result_response(final_response)
video = get_video_from_response(final_response) video = get_video_from_response(final_response)
@ -752,7 +748,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"mode", "mode",
options=modes, options=modes,
default=modes[4], default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
), ),
], ],
@ -834,7 +830,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProText2VideoRequest( data=OmniProText2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -929,7 +925,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest( data=OmniProFirstLastFrameRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -997,7 +993,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest( data=OmniProReferences2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -1081,7 +1077,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest( data=OmniProReferences2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -1162,7 +1158,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest( data=OmniProReferences2VideoRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -1237,7 +1233,7 @@ class OmniProImageNode(IO.ComfyNode):
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
data=OmniProImageRequest( data=OmniProImageRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
@ -1253,7 +1249,7 @@ class OmniProImageNode(IO.ComfyNode):
final_response = await poll_op( final_response = await poll_op(
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=OmniTaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@ -1328,9 +1324,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema: def define_schema(cls) -> IO.Schema:
return IO.Schema( return IO.Schema(
node_id="KlingImage2VideoNode", node_id="KlingImage2VideoNode",
display_name="Kling Image to Video", display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling", category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[ inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1488,7 +1483,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"mode", "mode",
options=modes, options=modes,
default=modes[8], default=modes[6],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
), ),
], ],
@ -1951,7 +1946,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"model_name", "model_name",
options=[i.value for i in KlingImageGenModelName], options=[i.value for i in KlingImageGenModelName],
default="kling-v1", default="kling-v2",
), ),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
@ -2034,6 +2029,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images)) return IO.NodeOutput(await image_result_to_node_output(images))
class TextToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class ImageToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class KlingExtension(ComfyExtension): class KlingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2056,7 +2181,9 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode, OmniProImageToVideoNode,
OmniProVideoToVideoNode, OmniProVideoToVideoNode,
OmniProEditVideoNode, OmniProEditVideoNode,
# OmniProImageNode, # need support from backend OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
] ]

View File

@ -1,575 +0,0 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()

View File

@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("model_seed", default=42, optional=True), IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True), IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True), IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
], ],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), IO.String.Output(display_name="model_file"),
@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None, model_seed: Optional[int] = None,
texture_seed: Optional[int] = None, texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None, texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
face_limit: Optional[int] = None, face_limit: Optional[int] = None,
quad: Optional[bool] = None, quad: Optional[bool] = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode):
texture_seed=texture_seed, texture_seed=texture_seed,
texture_quality=texture_quality, texture_quality=texture_quality,
face_limit=face_limit, face_limit=face_limit,
geometry_quality=geometry_quality,
auto_size=True, auto_size=True,
quad=quad, quad=quad,
), ),
@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
), ),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True), IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
], ],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), IO.String.Output(display_name="model_file"),
@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
orientation=None, orientation=None,
texture_seed: Optional[int] = None, texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None, texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None, texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None, face_limit: Optional[int] = None,
quad: Optional[bool] = None, quad: Optional[bool] = None,
@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode):
pbr=pbr, pbr=pbr,
model_seed=model_seed, model_seed=model_seed,
orientation=orientation, orientation=orientation,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment, texture_alignment=texture_alignment,
texture_seed=texture_seed, texture_seed=texture_seed,
texture_quality=texture_quality, texture_quality=texture_quality,
@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
), ),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True), IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
], ],
outputs=[ outputs=[
IO.String.Output(display_name="model_file"), IO.String.Output(display_name="model_file"),
@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None, model_seed: Optional[int] = None,
texture_seed: Optional[int] = None, texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None, texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None, texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None, face_limit: Optional[int] = None,
quad: Optional[bool] = None, quad: Optional[bool] = None,
@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed=model_seed, model_seed=model_seed,
texture_seed=texture_seed, texture_seed=texture_seed,
texture_quality=texture_quality, texture_quality=texture_quality,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment, texture_alignment=texture_alignment,
face_limit=face_limit, face_limit=face_limit,
quad=quad, quad=quad,
@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
options=[ options=[
"preset:idle", "preset:idle",
"preset:walk", "preset:walk",
"preset:run",
"preset:dive",
"preset:climb", "preset:climb",
"preset:jump", "preset:jump",
"preset:slash", "preset:slash",
@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hurt", "preset:hurt",
"preset:fall", "preset:fall",
"preset:turn", "preset:turn",
"preset:quadruped:walk",
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
], ],
), ),
], ],
@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit", "face_limit",
default=-1, default=-1,
min=-1, min=-1,
max=500000, max=2000000,
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
default="JPEG", default="JPEG",
optional=True, optional=True,
), ),
IO.Boolean.Input("force_symmetry", default=False, optional=True),
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
IO.Float.Input(
"flatten_bottom_threshold",
default=0.0,
min=0.0,
max=1.0,
optional=True,
),
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
IO.Float.Input(
"scale_factor",
default=1.0,
min=0.0,
optional=True,
),
IO.Boolean.Input("with_animation", default=False, optional=True),
IO.Boolean.Input("pack_uv", default=False, optional=True),
IO.Boolean.Input("bake", default=False, optional=True),
IO.String.Input("part_names", default="", optional=True), # comma-separated list
IO.Combo.Input(
"fbx_preset",
options=["blender", "mixamo", "3dsmax"],
default="blender",
optional=True,
),
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
IO.Combo.Input(
"export_orientation",
options=["align_image", "default"],
default="default",
optional=True,
),
IO.Boolean.Input("animate_in_place", default=False, optional=True),
], ],
outputs=[], outputs=[],
hidden=[ hidden=[
@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id, original_model_task_id,
format: str, format: str,
quad: bool, quad: bool,
force_symmetry: bool,
face_limit: int, face_limit: int,
flatten_bottom: bool,
flatten_bottom_threshold: float,
texture_size: int, texture_size: int,
texture_format: str, texture_format: str,
pivot_to_center_bottom: bool,
scale_factor: float,
with_animation: bool,
pack_uv: bool,
bake: bool,
part_names: str,
fbx_preset: str,
export_vertex_colors: bool,
export_orientation: str,
animate_in_place: bool,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if not original_model_task_id: if not original_model_task_id:
raise RuntimeError("original_model_task_id is required") raise RuntimeError("original_model_task_id is required")
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
response = await sync_op( response = await sync_op(
cls, cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id=original_model_task_id, original_model_task_id=original_model_task_id,
format=format, format=format,
quad=quad if quad else None, quad=quad if quad else None,
force_symmetry=force_symmetry if force_symmetry else None,
face_limit=face_limit if face_limit != -1 else None, face_limit=face_limit if face_limit != -1 else None,
flatten_bottom=flatten_bottom if flatten_bottom else None,
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
texture_size=texture_size if texture_size != 4096 else None, texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None, texture_format=texture_format if texture_format != "JPEG" else None,
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
scale_factor=scale_factor if scale_factor != 1.0 else None,
with_animation=with_animation if with_animation else None,
pack_uv=pack_uv if pack_uv else None,
bake=bake if bake else None,
part_names=part_names_list,
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
export_orientation=export_orientation if export_orientation != "default" else None,
animate_in_place=animate_in_place if animate_in_place else None,
), ),
) )
return await poll_until_finished(cls, response, average_duration=30) return await poll_until_finished(cls, response, average_duration=30)

View File

@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute get_sampler = execute
class SamplerSEEDS2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
],
outputs=[io.Sampler.Output()]
)
@classmethod
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
sampler_name = "seeds_2"
sampler = comfy.samplers.ksampler(
sampler_name,
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
)
return io.NodeOutput(sampler)
class Noise_EmptyNoise: class Noise_EmptyNoise:
def __init__(self): def __init__(self):
self.seed = 0 self.seed = 0
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative, SamplerDPMAdaptative,
SamplerER_SDE, SamplerER_SDE,
SamplerSASolver, SamplerSASolver,
SamplerSEEDS2,
SplitSigmas, SplitSigmas,
SplitSigmasDenoise, SplitSigmasDenoise,
FlipSigmas, FlipSigmas,

View File

@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod", node_id="FluxKontextMultiReferenceLatentMethod",
display_name="Edit Model Reference Method",
category="advanced/conditioning/flux", category="advanced/conditioning/flux",
inputs=[ inputs=[
io.Conditioning.Input("conditioning"), io.Conditioning.Input("conditioning"),
io.Combo.Input( io.Combo.Input(
"reference_latents_method", "reference_latents_method",
options=["offset", "index", "uxo/uno"], options=["offset", "index", "uxo/uno", "index_timestep_zero"],
), ),
], ],
outputs=[ outputs=[

View File

@ -243,7 +243,16 @@ class ModelPatchLoader:
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
sd = z_image_convert(sd) sd = z_image_convert(sd)
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) config = {}
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
config['n_control_layers'] = 15
config['additional_in_dim'] = 17
config['refiner_control'] = True
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
if ref_weight is not None:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
model.load_state_dict(sd) model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@ -297,62 +306,122 @@ class DiffSynthCnetPatch:
return [self.model_patch] return [self.model_patch]
class ZImageControlPatch: class ZImageControlPatch:
def __init__(self, model_patch, vae, image, strength): def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
self.model_patch = model_patch self.model_patch = model_patch
self.vae = vae self.vae = vae
self.image = image self.image = image
self.inpaint_image = inpaint_image
self.mask = mask
self.strength = strength self.strength = strength
self.encoded_image = self.encode_latent_cond(image) self.is_inpaint = self.model_patch.model.additional_in_dim > 0
self.encoded_image_size = (image.shape[1], image.shape[2])
skip_encoding = False
if self.image is not None and self.inpaint_image is not None:
if self.image.shape != self.inpaint_image.shape:
skip_encoding = True
if skip_encoding:
self.encoded_image = None
else:
self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
if self.image is None:
self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
else:
self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
self.temp_data = None self.temp_data = None
def encode_latent_cond(self, image): def encode_latent_cond(self, control_image=None, inpaint_image=None):
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image)) latent_image = None
return latent_image if control_image is not None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
if self.is_inpaint:
if inpaint_image is None:
inpaint_image = torch.ones_like(control_image) * 0.5
if self.mask is not None:
mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
if self.mask is None:
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
else:
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
if latent_image is None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
else:
return latent_image
def __call__(self, kwargs): def __call__(self, kwargs):
x = kwargs.get("x") x = kwargs.get("x")
img = kwargs.get("img") img = kwargs.get("img")
img_input = kwargs.get("img_input")
txt = kwargs.get("txt") txt = kwargs.get("txt")
pe = kwargs.get("pe") pe = kwargs.get("pe")
vec = kwargs.get("vec") vec = kwargs.get("vec")
block_index = kwargs.get("block_index") block_index = kwargs.get("block_index")
block_type = kwargs.get("block_type", "")
spacial_compression = self.vae.spacial_compression_encode() spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") image_scaled = None
if self.image is not None:
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
inpaint_scaled = None
if self.inpaint_image is not None:
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
loaded_models = comfy.model_management.loaded_models(only_currently_used=True) loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1)) self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
comfy.model_management.load_models_gpu(loaded_models) comfy.model_management.load_models_gpu(loaded_models)
cnet_index = (block_index // 5) cnet_blocks = self.model_patch.model.n_control_layers
cnet_index_float = (block_index / 5) div = round(30 / cnet_blocks)
cnet_index = (block_index // div)
cnet_index_float = (block_index / div)
kwargs.pop("img") # we do ops in place kwargs.pop("img") # we do ops in place
kwargs.pop("txt") kwargs.pop("txt")
cnet_blocks = self.model_patch.model.n_control_layers
if cnet_index_float > (cnet_blocks - 1): if cnet_index_float > (cnet_blocks - 1):
self.temp_data = None self.temp_data = None
return kwargs return kwargs
if self.temp_data is None or self.temp_data[0] > cnet_index: if self.temp_data is None or self.temp_data[0] > cnet_index:
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) if block_type == "noise_refiner":
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
else:
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: if block_type == "noise_refiner":
next_layer = self.temp_data[0] + 1 next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if self.temp_data[1][0] is not None:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
else:
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if cnet_index_float == self.temp_data[0]: if cnet_index_float == self.temp_data[0]:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
if cnet_blocks == self.temp_data[0] + 1: if cnet_blocks == self.temp_data[0] + 1:
self.temp_data = None self.temp_data = None
return kwargs return kwargs
def to(self, device_or_dtype): def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device): if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype) if self.encoded_image is not None:
self.encoded_image = self.encoded_image.to(device_or_dtype)
self.temp_data = None self.temp_data = None
return self return self
@ -375,9 +444,12 @@ class QwenImageDiffsynthControlnet:
CATEGORY = "advanced/loaders/qwen" CATEGORY = "advanced/loaders/qwen"
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None): def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
model_patched = model.clone() model_patched = model.clone()
image = image[:, :, :, :3] if image is not None:
image = image[:, :, :, :3]
if inpaint_image is not None:
inpaint_image = inpaint_image[:, :, :, :3]
if mask is not None: if mask is not None:
if mask.ndim == 3: if mask.ndim == 3:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
@ -386,11 +458,24 @@ class QwenImageDiffsynthControlnet:
mask = 1.0 - mask mask = 1.0 - mask
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength)) patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
model_patched.set_model_noise_refiner_patch(patch)
model_patched.set_model_double_block_patch(patch)
else: else:
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,) return (model_patched,)
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
CATEGORY = "advanced/loaders/zimage"
class UsoStyleProjectorPatch: class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image): def __init__(self, model_patch, encoded_image):
@ -438,5 +523,6 @@ class USOStyleReference:
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader, "ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"ZImageFunControlnet": ZImageFunControlnet,
"USOStyleReference": USOStyleReference, "USOStyleReference": USOStyleReference,
} }

View File

@ -2,6 +2,8 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper from comfy_api.torch_helpers import set_torch_compile_wrapper
def skip_torch_compile_dict(guard_entries):
return [("transformer_options" not in entry.name) for entry in guard_entries]
class TorchCompileModel(io.ComfyNode): class TorchCompileModel(io.ComfyNode):
@classmethod @classmethod
@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model, backend) -> io.NodeOutput: def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone() m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend) set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
return io.NodeOutput(m) return io.NodeOutput(m)

View File

@ -0,0 +1,535 @@
import nodes
import node_helpers
import torch
import torchvision.transforms.functional as TF
import comfy.model_management
import comfy.utils
import numpy as np
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_wan import parse_json_tracks
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
from PIL import Image, ImageDraw
SKIP_ZERO = False
def get_pos_emb(
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
pos_emb_dim: int,
theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
pos_k = pos_k.to(device, dtype)
if SKIP_ZERO:
pos_k = pos_k + 1
batch_size = pos_k.size(0)
denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
# Expand denominator to match the shape needed for broadcasting
denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
thetas = theta_func(denominator_expanded, pos_emb_dim)
# Ensure pos_k is in the correct shape for broadcasting
pos_k_expanded = pos_k.view(-1, 1).to(dtype)
sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
# Concatenate sine and cosine embeddings along the last dimension
pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
return pos_emb
def create_pos_embeddings(
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
downsample_ratios: list[int], # the ratios for downsampling time, height, and width
height: int, # the height of the feature map
width: int, # the width of the feature map
track_num: int = -1, # the number of tracks to use
t_down_strategy: str = "sample", # the strategy for downsampling time dimension
):
assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
t, n, _ = pred_tracks.shape
t_down, h_down, w_down = downsample_ratios
track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
if track_num == -1:
track_num = n
tracks_idx = torch.randperm(n)[:track_num]
tracks = pred_tracks[:, tracks_idx]
visibility = pred_visibility[:, tracks_idx]
for t_idx in range(0, t, t_down):
if t_down_strategy == "sample" or t_idx == 0:
cur_tracks = tracks[t_idx] # [N, 2]
cur_visibility = visibility[t_idx] # [N]
else:
cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
for i in range(track_num):
if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
continue
x, y = cur_tracks[i]
x, y = int(x // w_down), int(y // h_down)
track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
def replace_feature(
vae_feature: torch.Tensor, # [B, C', T', H', W']
track_pos: torch.Tensor, # [B, N, T', 2]
strength: float = 1.0
) -> torch.Tensor:
b, _, t, h, w = vae_feature.shape
assert b == track_pos.shape[0], "Batch size mismatch."
n = track_pos.shape[1]
# Shuffle the trajectory order
track_pos = track_pos[:, torch.randperm(n), :, :]
# Extract coordinates at time steps ≥ 1 and generate a valid mask
current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
# Get all valid indices
valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
num_valid = valid_indices.shape[0]
if num_valid == 0:
return vae_feature
# Decompose valid indices into each dimension
batch_idx = valid_indices[:, 0]
track_idx = valid_indices[:, 1]
t_rel = valid_indices[:, 2]
t_target = t_rel + 1 # Convert to original time step indices
# Extract target position coordinates
h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
# Extract source position coordinates (t=0)
h_source = track_pos[batch_idx, track_idx, 0, 0].long()
w_source = track_pos[batch_idx, track_idx, 0, 1].long()
# Get source features and assign to target positions
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
return vae_feature
# Visualize functions
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
draw = ImageDraw.Draw(overlay, 'RGBA')
points = points[::-1]
# Compute total length
total_length = 0
segment_lengths = []
for i in range(len(points) - 1):
dx = points[i + 1][0] - points[i][0]
dy = points[i + 1][1] - points[i][1]
length = (dx * dx + dy * dy) ** 0.5
segment_lengths.append(length)
total_length += length
if total_length == 0:
return
accumulated_length = 0
# Draw the gradient polyline
for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
segment_length = segment_lengths[idx]
steps = max(int(segment_length), 1)
for i in range(steps):
current_length = accumulated_length + (i / steps) * segment_length
ratio = current_length / total_length
alpha = int(255 * (1 - ratio) * opacity)
color = (*start_color, alpha)
x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
accumulated_length += segment_length
def add_weighted(rgb, track):
rgb = np.array(rgb) # [H, W, C] "RGB"
track = np.array(track) # [H, W, C] "RGBA"
alpha = track[:, :, 3] / 255.0
alpha = np.stack([alpha] * 3, axis=-1)
blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
return Image.fromarray(blend_img.astype(np.uint8))
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
video = video.byte().cpu().numpy() # (81, 480, 832, 3)
tracks = tracks[0].long().detach().cpu().numpy()
if visibility is not None:
visibility = visibility[0].detach().cpu().numpy()
num_frames, height, width = video.shape[:3]
num_tracks = tracks.shape[1]
alpha_opacity = int(255 * opacity)
output_frames = []
for t in range(num_frames):
frame_rgb = video[t].astype(np.float32)
# Create a single RGBA overlay for all tracks in this frame
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
draw_overlay = ImageDraw.Draw(overlay)
polyline_data = []
# Draw all circles on a single overlay
for n in range(num_tracks):
if visibility is not None and visibility[t, n] == 0:
continue
track_coord = tracks[t, n]
color = color_map[n % len(color_map)]
circle_color = color + (alpha_opacity,)
draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
fill=circle_color
)
# Store polyline data for batch processing
tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
if len(tracks_coord) > 1:
polyline_data.append((tracks_coord, color))
# Blend circles overlay once
overlay_np = np.array(overlay)
alpha = overlay_np[:, :, 3:4] / 255.0
frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
# Draw all polylines on a single overlay
if polyline_data:
polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
for tracks_coord, color in polyline_data:
_draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
# Blend polylines overlay once
polyline_np = np.array(polyline_overlay)
alpha = polyline_np[:, :, 3:4] / 255.0
frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
return output_frames
class WanMoveVisualizeTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveVisualizeTracks",
category="conditioning/video_models",
inputs=[
io.Image.Input("images"),
io.Tracks.Input("tracks", optional=True),
io.Int.Input("line_resolution", default=24, min=1, max=1024),
io.Int.Input("circle_size", default=12, min=1, max=128),
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
io.Int.Input("line_width", default=16, min=1, max=128),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
if tracks is None:
return io.NodeOutput(images)
track_path = tracks["track_path"].unsqueeze(0)
track_visibility = tracks["track_visibility"].unsqueeze(0)
images_in = images * 255.0
if images_in.shape[0] != track_path.shape[1]:
repeat_count = track_path.shape[1] // images.shape[0]
images_in = images_in.repeat(repeat_count, 1, 1, 1)
track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
class WanMoveTracksFromCoords(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTracksFromCoords",
category="conditioning/video_models",
inputs=[
io.String.Input("track_coords", force_input=True, default="[]", optional=True),
io.Mask.Input("track_mask", optional=True),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
device=comfy.model_management.intermediate_device()
tracks_data = parse_json_tracks(track_coords)
track_length = len(tracks_data[0])
track_list = [
[[track[frame]['x'], track[frame]['y']] for track in tracks_data]
for frame in range(len(tracks_data[0]))
]
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
num_tracks = tracks.shape[-2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class GenerateTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GenerateTracks",
category="conditioning/video_models",
inputs=[
io.Int.Input("width", default=832, min=16, max=4096, step=16),
io.Int.Input("height", default=480, min=16, max=4096, step=16),
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
io.Int.Input("num_frames", default=81, min=1, max=1024),
io.Int.Input("num_tracks", default=5, min=1, max=100),
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Combo.Input(
"interpolation",
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
tooltip="Controls the timing/speed of movement along the path.",
),
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
track_length = num_frames
# normalized coordinates to pixel coordinates
start_x_px = start_x * width
start_y_px = start_y * height
mid_x_px = mid_x * width
mid_y_px = mid_y * height
end_x_px = end_x * width
end_y_px = end_y * height
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
t = torch.linspace(0, 1, num_frames, device=device)
if interpolation == "constant": # All points stay at start position
interp_values = torch.zeros_like(t)
elif interpolation == "linear":
interp_values = t
elif interpolation == "ease_in":
interp_values = t ** 2
elif interpolation == "ease_out":
interp_values = 1 - (1 - t) ** 2
elif interpolation == "ease_in_out":
interp_values = t * t * (3 - 2 * t)
if bezier: # apply interpolation to t for timing control along the bezier path
t_interp = interp_values
one_minus_t = 1 - t_interp
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
else: # calculate base x and y positions for each frame (center track)
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
# For non-bezier, tangent is constant (direction from start to end)
tangent_x = torch.full_like(t, end_x_px - start_x_px)
tangent_y = torch.full_like(t, end_y_px - start_y_px)
track_list = []
for frame_idx in range(num_frames):
# Calculate perpendicular direction at this frame
tx = tangent_x[frame_idx].item()
ty = tangent_y[frame_idx].item()
length = (tx ** 2 + ty ** 2) ** 0.5
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
perp_x = -ty / length
perp_y = tx / length
else: # If tangent is zero, spread horizontally
perp_x = 1.0
perp_y = 0.0
frame_tracks = []
for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
track_x = x_positions[frame_idx].item() + perp_x * offset
track_y = y_positions[frame_idx].item() + perp_y * offset
frame_tracks.append([track_x, track_y])
track_list.append(frame_tracks)
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class WanMoveConcatTrack(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveConcatTrack",
category="conditioning/video_models",
inputs=[
io.Tracks.Input("tracks_1"),
io.Tracks.Input("tracks_2", optional=True),
],
outputs=[
io.Tracks.Output(),
],
)
@classmethod
def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
if tracks_2 is None:
return io.NodeOutput(tracks_1)
tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension
mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)
out_track_info = {}
out_track_info["track_path"] = tracks_out
out_track_info["track_visibility"] = mask_out
return io.NodeOutput(out_track_info)
class WanMoveTrackToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Tracks.Input("tracks", optional=True),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
device=comfy.model_management.intermediate_device()
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if tracks is not None and strength > 0.0:
tracks_path = tracks["track_path"][:length] # [T, N, 2]
num_tracks = tracks_path.shape[-2]
track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))
track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
else:
concat_latent_image_pos = concat_latent_image
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanMoveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanMoveTrackToVideo,
WanMoveTracksFromCoords,
WanMoveConcatTrack,
WanMoveVisualizeTracks,
GenerateTracks,
]
async def comfy_entrypoint() -> WanMoveExtension:
return WanMoveExtension()

View File

@ -13,6 +13,7 @@ import asyncio
import torch import torch
import comfy.model_management import comfy.model_management
from latent_preview import set_preview_method
import nodes import nodes
from comfy_execution.caching import ( from comfy_execution.caching import (
BasicCache, BasicCache,
@ -669,6 +670,8 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False) nodes.interrupt_processing(False)
if "client_id" in extra_data: if "client_id" in extra_data:

View File

@ -8,6 +8,8 @@ import folder_paths
import comfy.utils import comfy.utils
import logging import logging
default_preview_method = args.preview_method
MAX_PREVIEW_RESOLUTION = args.preview_size MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None):
pbar.update_absolute(step + 1, total_steps, preview_bytes) pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback return callback
def set_preview_method(override: str = None):
if override and override != "default":
method = LatentPreviewMethod.from_string(override)
if method is not None:
args.preview_method = method
return
args.preview_method = default_preview_method

View File

@ -1 +1 @@
comfyui_manager==4.0.3b4 comfyui_manager==4.0.3b5

View File

@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes():
"nodes_logic.py", "nodes_logic.py",
"nodes_nop.py", "nodes_nop.py",
"nodes_kandinsky5.py", "nodes_kandinsky5.py",
"nodes_wanmove.py",
] ]
import_failed = [] import_failed = []
@ -2383,7 +2384,6 @@ async def init_builtin_api_nodes():
"nodes_recraft.py", "nodes_recraft.py",
"nodes_pixverse.py", "nodes_pixverse.py",
"nodes_stability.py", "nodes_stability.py",
"nodes_pika.py",
"nodes_runway.py", "nodes_runway.py",
"nodes_sora.py", "nodes_sora.py",
"nodes_topaz.py", "nodes_topaz.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.13 comfyui-frontend-package==1.34.9
comfyui-workflow-templates==0.7.54 comfyui-workflow-templates==0.7.59
comfyui-embedded-docs==0.3.1 comfyui-embedded-docs==0.3.1
torch torch
torchsde torchsde

View File

@ -0,0 +1,352 @@
"""
Unit tests for Queue-specific Preview Method Override feature.
Tests the preview method override functionality:
- LatentPreviewMethod.from_string() method
- set_preview_method() function in latent_preview.py
- default_preview_method variable
- Integration with args.preview_method
"""
import pytest
from comfy.cli_args import args, LatentPreviewMethod
from latent_preview import set_preview_method, default_preview_method
class TestLatentPreviewMethodFromString:
"""Test LatentPreviewMethod.from_string() classmethod."""
@pytest.mark.parametrize("value,expected", [
("auto", LatentPreviewMethod.Auto),
("latent2rgb", LatentPreviewMethod.Latent2RGB),
("taesd", LatentPreviewMethod.TAESD),
("none", LatentPreviewMethod.NoPreviews),
])
def test_valid_values_return_enum(self, value, expected):
"""Valid string values should return corresponding enum."""
assert LatentPreviewMethod.from_string(value) == expected
@pytest.mark.parametrize("invalid", [
"invalid",
"TAESD", # Case sensitive
"AUTO", # Case sensitive
"Latent2RGB", # Case sensitive
"latent",
"",
"default", # default is special, not a method
])
def test_invalid_values_return_none(self, invalid):
"""Invalid string values should return None."""
assert LatentPreviewMethod.from_string(invalid) is None
class TestLatentPreviewMethodEnumValues:
"""Test LatentPreviewMethod enum has expected values."""
def test_enum_values(self):
"""Verify enum values match expected strings."""
assert LatentPreviewMethod.NoPreviews.value == "none"
assert LatentPreviewMethod.Auto.value == "auto"
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
assert LatentPreviewMethod.TAESD.value == "taesd"
def test_enum_count(self):
"""Verify exactly 4 preview methods exist."""
assert len(LatentPreviewMethod) == 4
class TestSetPreviewMethod:
"""Test set_preview_method() function from latent_preview.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_override_with_taesd(self):
"""'taesd' should set args.preview_method to TAESD."""
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
def test_override_with_latent2rgb(self):
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_override_with_auto(self):
"""'auto' should set args.preview_method to Auto."""
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
def test_override_with_none_value(self):
"""'none' should set args.preview_method to NoPreviews."""
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_default_restores_original(self):
"""'default' should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use 'default' to restore
set_preview_method("default")
assert args.preview_method == default_preview_method
def test_none_param_restores_original(self):
"""None parameter should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use None to restore
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_empty_string_restores_original(self):
"""Empty string should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("")
assert args.preview_method == default_preview_method
def test_invalid_value_restores_original(self):
"""Invalid value should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("invalid_method")
assert args.preview_method == default_preview_method
def test_case_sensitive_invalid_restores(self):
"""Case-mismatched values should restore to default."""
set_preview_method("taesd")
set_preview_method("TAESD") # Wrong case
assert args.preview_method == default_preview_method
class TestDefaultPreviewMethod:
"""Test default_preview_method module variable."""
def test_default_is_not_none(self):
"""default_preview_method should not be None."""
assert default_preview_method is not None
def test_default_is_enum_member(self):
"""default_preview_method should be a LatentPreviewMethod enum."""
assert isinstance(default_preview_method, LatentPreviewMethod)
def test_default_matches_args_initial(self):
"""default_preview_method should match CLI default or user setting."""
# This tests that default_preview_method was captured at module load
# After set_preview_method(None), args should equal default
original = args.preview_method
set_preview_method("taesd")
set_preview_method(None)
assert args.preview_method == default_preview_method
args.preview_method = original
class TestArgsPreviewMethodModification:
"""Test args.preview_method can be modified correctly."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_args_accepts_all_enum_values(self):
"""args.preview_method should accept all LatentPreviewMethod values."""
for method in LatentPreviewMethod:
args.preview_method = method
assert args.preview_method == method
def test_args_modification_and_restoration(self):
"""args.preview_method should be modifiable and restorable."""
original = args.preview_method
args.preview_method = LatentPreviewMethod.TAESD
assert args.preview_method == LatentPreviewMethod.TAESD
args.preview_method = original
assert args.preview_method == original
class TestExecutionFlow:
"""Test the execution flow pattern used in execution.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_sequential_executions_with_different_methods(self):
"""Simulate multiple queue executions with different preview methods."""
# Execution 1: taesd
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Execution 2: none
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Execution 3: default (restore)
set_preview_method("default")
assert args.preview_method == default_preview_method
# Execution 4: auto
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
# Execution 5: no override (None)
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_override_then_default_pattern(self):
"""Test the pattern: override -> execute -> next call restores."""
# First execution with override
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
# Second execution without override restores default
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_extra_data_simulation(self):
"""Simulate extra_data.get('preview_method') patterns."""
# Simulate: extra_data = {"preview_method": "taesd"}
extra_data = {"preview_method": "taesd"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Simulate: extra_data = {}
extra_data = {}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
# Simulate: extra_data = {"preview_method": "default"}
extra_data = {"preview_method": "default"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
class TestRealWorldScenarios:
"""Tests using real-world prompt data patterns."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_captured_prompt_without_preview_method(self):
"""
Test with captured prompt that has no preview_method.
Based on: tests-unit/execution_test/fixtures/default_prompt.json
"""
# Real captured extra_data structure (preview_method absent)
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"create_time": 1765416558179
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_captured_prompt_with_preview_method_taesd(self):
"""Test captured prompt with preview_method: taesd."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"preview_method": "taesd"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
def test_captured_prompt_with_preview_method_none(self):
"""Test captured prompt with preview_method: none (disable preview)."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "none"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_captured_prompt_with_preview_method_latent2rgb(self):
"""Test captured prompt with preview_method: latent2rgb."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "latent2rgb"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_captured_prompt_with_preview_method_auto(self):
"""Test captured prompt with preview_method: auto."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "auto"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Auto
def test_captured_prompt_with_preview_method_default(self):
"""Test captured prompt with preview_method: default (use CLI setting)."""
# First set to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then simulate a prompt with "default"
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "default"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_sequential_queue_with_different_preview_methods(self):
"""
Simulate real queue scenario: multiple prompts with different settings.
This tests the actual usage pattern in ComfyUI.
"""
# Queue 1: User wants TAESD preview
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
set_preview_method(extra_data_1.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Queue 2: User wants no preview (faster execution)
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
set_preview_method(extra_data_2.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Queue 3: User doesn't specify (use server default)
extra_data_3 = {"client_id": "client-3"}
set_preview_method(extra_data_3.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 4: User explicitly wants default
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
set_preview_method(extra_data_4.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 5: User wants latent2rgb
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
set_preview_method(extra_data_5.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB

View File

@ -0,0 +1,358 @@
"""
E2E tests for Queue-specific Preview Method Override feature.
Tests actual execution with different preview_method values.
Requires a running ComfyUI server with models.
Usage:
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
Note:
These tests execute actual image generation and wait for completion.
Tests verify preview image transmission based on preview_method setting.
"""
import os
import json
import pytest
import uuid
import time
import random
import websocket
import urllib.request
from pathlib import Path
# Server configuration
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
# Use existing inference graph fixture
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
def is_server_running() -> bool:
"""Check if ComfyUI server is running."""
try:
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
with urllib.request.urlopen(request, timeout=2.0):
return True
except Exception:
return False
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
"""Prepare graph for testing: randomize seeds and reduce steps."""
adapted = json.loads(json.dumps(graph)) # Deep copy
for node_id, node in adapted.items():
inputs = node.get("inputs", {})
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
if "seed" in inputs:
inputs["seed"] = random.randint(0, 2**32 - 1)
if "noise_seed" in inputs:
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
# Reduce steps for faster testing (default 20 -> 5)
if "steps" in inputs:
inputs["steps"] = steps
return adapted
# Alias for backward compatibility
randomize_seed = prepare_graph_for_test
class PreviewMethodClient:
"""Client for testing preview_method with WebSocket execution tracking."""
def __init__(self, server_address: str):
self.server_address = server_address
self.client_id = str(uuid.uuid4())
self.ws = None
def connect(self):
"""Connect to WebSocket."""
self.ws = websocket.WebSocket()
self.ws.settimeout(120) # 2 minute timeout for sampling
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
def close(self):
"""Close WebSocket connection."""
if self.ws:
self.ws.close()
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
"""Queue a prompt and return response with prompt_id."""
data = {
"prompt": prompt,
"client_id": self.client_id,
"extra_data": extra_data or {}
}
req = urllib.request.Request(
f"http://{self.server_address}/prompt",
data=json.dumps(data).encode("utf-8"),
headers={"Content-Type": "application/json"}
)
return json.loads(urllib.request.urlopen(req).read())
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
"""
Wait for execution to complete via WebSocket.
Returns:
dict with keys: completed, error, preview_count, execution_time
"""
result = {
"completed": False,
"error": None,
"preview_count": 0,
"execution_time": 0.0
}
start_time = time.time()
self.ws.settimeout(timeout)
try:
while True:
out = self.ws.recv()
elapsed = time.time() - start_time
if isinstance(out, str):
message = json.loads(out)
msg_type = message.get("type")
data = message.get("data", {})
if data.get("prompt_id") != prompt_id:
continue
if msg_type == "executing":
if data.get("node") is None:
# Execution complete
result["completed"] = True
result["execution_time"] = elapsed
break
elif msg_type == "execution_error":
result["error"] = data
result["execution_time"] = elapsed
break
elif msg_type == "progress":
# Progress update during sampling
pass
elif isinstance(out, bytes):
# Binary data = preview image
result["preview_count"] += 1
except websocket.WebSocketTimeoutException:
result["error"] = "Timeout waiting for execution"
result["execution_time"] = time.time() - start_time
return result
def load_graph() -> dict:
"""Load the SDXL graph fixture with randomized seed."""
with open(GRAPH_FILE) as f:
graph = json.load(f)
return randomize_seed(graph) # Avoid caching
# Skip all tests if server is not running
pytestmark = [
pytest.mark.skipif(
not is_server_running(),
reason=f"ComfyUI server not running at {SERVER_URL}"
),
pytest.mark.preview_method,
pytest.mark.execution,
]
@pytest.fixture
def client():
"""Create and connect a test client."""
c = PreviewMethodClient(SERVER_HOST)
c.connect()
yield c
c.close()
@pytest.fixture
def graph():
"""Load the test graph."""
return load_graph()
class TestPreviewMethodExecution:
"""Test actual execution with different preview methods."""
def test_execution_with_latent2rgb(self, client, graph):
"""
Execute with preview_method=latent2rgb.
Should complete and potentially receive preview images.
"""
extra_data = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
# Should complete (may error if model missing, but that's separate)
assert result["completed"] or result["error"] is not None
# Execution should take some time (sampling)
if result["completed"]:
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
# latent2rgb should produce previews
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_taesd(self, client, graph):
"""
Execute with preview_method=taesd.
TAESD provides higher quality previews.
"""
extra_data = {"preview_method": "taesd"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
assert result["execution_time"] > 0.5
# taesd should also produce previews
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_none_preview(self, client, graph):
"""
Execute with preview_method=none.
No preview images should be generated.
"""
extra_data = {"preview_method": "none"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
# With "none", should receive no preview images
assert result["preview_count"] == 0, \
f"Expected no previews with 'none', got {result['preview_count']}"
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_default(self, client, graph):
"""
Execute with preview_method=default.
Should use server's CLI default setting.
"""
extra_data = {"preview_method": "default"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_without_preview_method(self, client, graph):
"""
Execute without preview_method in extra_data.
Should use server's default preview method.
"""
extra_data = {} # No preview_method
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
class TestPreviewMethodComparison:
"""Compare preview behavior between different methods."""
def test_none_vs_latent2rgb_preview_count(self, client, graph):
"""
Compare preview counts: 'none' should have 0, others should have >0.
This is the key verification that preview_method actually works.
"""
results = {}
# Run with none (randomize seed to avoid caching)
graph_none = randomize_seed(graph)
extra_data_none = {"preview_method": "none"}
response = client.queue_prompt(graph_none, extra_data_none)
results["none"] = client.wait_for_execution(response["prompt_id"])
# Run with latent2rgb (randomize seed again)
graph_rgb = randomize_seed(graph)
extra_data_rgb = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph_rgb, extra_data_rgb)
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
# Verify both completed
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
# Key assertion: 'none' should have 0 previews
assert results["none"]["preview_count"] == 0, \
f"'none' should have 0 previews, got {results['none']['preview_count']}"
# 'latent2rgb' should have at least 1 preview (depends on steps)
assert results["latent2rgb"]["preview_count"] > 0, \
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
print("\nPreview count comparison:") # noqa: T201
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
class TestPreviewMethodSequential:
"""Test sequential execution with different preview methods."""
def test_sequential_different_methods(self, client, graph):
"""
Execute multiple prompts sequentially with different preview methods.
Each should complete independently with correct preview behavior.
"""
methods = ["latent2rgb", "none", "default"]
results = []
for method in methods:
# Randomize seed for each execution to avoid caching
graph_run = randomize_seed(graph)
extra_data = {"preview_method": method}
response = client.queue_prompt(graph_run, extra_data)
result = client.wait_for_execution(response["prompt_id"])
results.append({
"method": method,
"completed": result["completed"],
"preview_count": result["preview_count"],
"execution_time": result["execution_time"],
"error": result["error"]
})
# All should complete or have clear errors
for r in results:
assert r["completed"] or r["error"] is not None, \
f"Method {r['method']} neither completed nor errored"
# "none" should have zero previews if completed
none_result = next(r for r in results if r["method"] == "none")
if none_result["completed"]:
assert none_result["preview_count"] == 0, \
f"'none' should have 0 previews, got {none_result['preview_count']}"
print("\nSequential execution results:") # noqa: T201
for r in results:
status = "" if r["completed"] else f"✗ ({r['error']})"
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201