Compare commits

...

37 Commits

Author SHA1 Message Date
Benjamin Lu
9304e47351
Update workflows for new release process (#11064)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.9) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
* Update release workflows for branch process

* Adjust branch order in workflow triggers

* Revert changes in test workflows
2025-12-15 23:24:18 -08:00
comfyanonymous
bc606d7d64
Add a way to set the default ref method in the qwen image code. (#11349) 2025-12-16 01:26:55 -05:00
comfyanonymous
645ee1881e
Inpainting for z image fun control. Use the ZImageFunControlnet node. (#11346)
image -> control image ex: pose
inpaint_image -> image for inpainting
mask -> inpaint mask
2025-12-15 23:38:12 -05:00
Christian Byrne
3d082c3206
bump comfyui-frontend-package to 1.34.9 (patch) (#11342) 2025-12-15 23:35:37 -05:00
comfyanonymous
683569de55
Only enable fp16 on ZImage on newer pytorch. (#11344) 2025-12-15 22:33:27 -05:00
Haoming
ea2c117bc3
[BlockInfo] Wan (#10845)
* block info

* animate

* tensor

* device

* revert
2025-12-15 17:59:16 -08:00
Haoming
fc4af86068
[BlockInfo] Lumina (#11227)
* block info

* device

* Make tensor int again

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2025-12-15 17:57:28 -08:00
comfyanonymous
41bcf0619d
Add code to detect if a z image fun controlnet is broken or not. (#11341) 2025-12-15 20:51:06 -05:00
seed93
d02d0e5744
[add] tripo3.0 (#10663)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.9) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* [add] tripo3.0

* [tripo] change paramter order

* change order

---------

Co-authored-by: liangd <liangding@vastai3d.com>
2025-12-15 17:38:46 -08:00
comfyanonymous
70541d4e77
Support the new qwen edit 2511 reference method. (#11340)
index_timestep_zero can be selected in the
FluxKontextMultiReferenceLatentMethod now with the display name set to the
more generic "Edit Model Reference Method" node.
2025-12-15 19:20:34 -05:00
drozbay
77b2f7c228
Add context windows callback for custom cond handling (#11208)
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
2025-12-15 16:06:32 -08:00
Alexander Piskun
43e0d4e3cc
comfy_api: remove usage of "Type","List" and "Dict" types (#11238) 2025-12-15 16:01:10 -08:00
Dr.Lt.Data
dbd330454a
feat(preview): add per-queue live preview method override (#11261)
- Add set_preview_method() to override live preview method per queue item
- Read extra_data.preview_method from /prompt request
- Support values: taesd, latent2rgb, none, auto, default
- "default" or unset uses server's CLI --preview-method setting
- Add 44 tests (37 unit + 7 E2E)
2025-12-15 15:57:39 -08:00
Alexander Piskun
33c7f1179d
drop Pika API nodes (#11306) 2025-12-15 15:32:29 -08:00
Alexander Piskun
af91eb6c99
api-nodes: drop Kling v1 model (#11307) 2025-12-15 15:30:24 -08:00
comfyanonymous
5cb1e0c9a0
Disable guards on transformer_options when torch.compile (#11317) 2025-12-15 16:49:29 -05:00
ComfyUI Wiki
51347f9fb8
chore: update workflow templates to v0.7.59 (#11337) 2025-12-15 16:28:55 -05:00
Dr.Lt.Data
a5e85017d8
bump manager requirments to the 4.0.3b5 (#11324) 2025-12-15 14:24:01 -05:00
comfyanonymous
5ac3b26a7d
Update warning for old pytorch version. (#11319)
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Has been cancelled
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Has been cancelled
Execution Tests / test (macos-latest) (push) Has been cancelled
Execution Tests / test (ubuntu-latest) (push) Has been cancelled
Execution Tests / test (windows-latest) (push) Has been cancelled
Test server launches without errors / test (push) Has been cancelled
Unit Tests / test (macos-latest) (push) Has been cancelled
Unit Tests / test (ubuntu-latest) (push) Has been cancelled
Unit Tests / test (windows-2022) (push) Has been cancelled
Versions below 2.4 are no longer supported. We will not break support on purpose but will not fix it if we do.
2025-12-14 04:02:50 -05:00
chaObserv
6592bffc60
seeds_2: add phi_2 variant and sampler node (#11309)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
* Add phi_2 solver type to seeds_2

* Add sampler node of seeds_2
2025-12-14 00:03:29 -05:00
comfyanonymous
971cefe7d4
Fix pytorch warnings. (#11314)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-13 18:45:23 -05:00
comfyanonymous
da2bfb5b0a
Basic implementation of z image fun control union 2.0 (#11304)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
The inpaint part is currently missing and will be implemented later.

I think they messed up this model pretty bad. They added some
control_noise_refiner blocks but don't actually use them. There is a typo
in their code so instead of doing control_noise_refiner -> control_layers
it runs the whole control_layers twice.

Unfortunately they trained with this typo so the model works but is kind
of slow and would probably perform a lot better if they corrected their
code and trained it again.
2025-12-13 01:39:11 -05:00
comfyanonymous
c5a47a1692
Fix bias dtype issue in mixed ops. (#11293)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-12 11:49:35 -05:00
Alexander Piskun
908fd7d749
feat(api-nodes): new TextToVideoWithAudio and ImageToVideoWithAudio nodes (#11267)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-12 00:18:31 -08:00
comfyanonymous
5495589db3
Respect the dtype the op was initialized in for non quant mixed op. (#11282)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-11 23:32:27 -05:00
Jukka Seppänen
982876d59a
WanMove support (#11247) 2025-12-11 22:29:34 -05:00
comfyanonymous
338d9ae3bb
Make portable updater work with repos in unmerged state. (#11281) 2025-12-11 18:56:33 -05:00
comfyanonymous
eeb020b9b7
Better chroma radiance and other models vram estimation. (#11278)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
2025-12-11 17:33:09 -05:00
comfyanonymous
ae65433a60
This only works on radiance. (#11277) 2025-12-11 17:15:00 -05:00
comfyanonymous
fdebe18296
Fix regular chroma radiance (#11276) 2025-12-11 17:09:35 -05:00
comfyanonymous
f8321eb57b
Adjust memory usage factor. (#11257)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-11 01:30:31 -05:00
Alexander Piskun
93948e3fc5
feat(api-nodes): enable Kling Omni O1 node (#11229) 2025-12-10 22:11:12 -08:00
Farshore
e711aaf1a7
Lower VAE loading requirements:Create a new branch for GPU memory calculations in qwen-image vae (#11199) 2025-12-10 22:02:26 -05:00
Johnpaul Chiwetelu
57ddb7fd13
Fix: filter hidden files from /internal/files endpoint (#11191) 2025-12-10 21:49:49 -05:00
comfyanonymous
17c92a9f28
Tweak Z Image memory estimation. (#11254) 2025-12-10 19:59:48 -05:00
Alexander Piskun
36357bbcc3
process the NodeV1 dict results correctly (#11237)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2025-12-10 11:55:09 -08:00
Benjamin Lu
f668c2e3c9
bump comfyui-frontend-package to 1.34.8 (#11220)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
2025-12-09 22:27:07 -05:00
48 changed files with 2003 additions and 860 deletions

View File

@ -53,6 +53,16 @@ try:
repo.stash(ident)
except KeyError:
print("nothing to stash") # noqa: T201
except:
print("Could not stash, cleaning index and trying again.") # noqa: T201
repo.state_cleanup()
repo.index.read_tree(repo.head.peel().tree)
repo.index.write()
try:
repo.stash(ident)
except KeyError:
print("nothing to stash.") # noqa: T201
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
try:

View File

@ -5,6 +5,7 @@ on:
push:
branches:
- master
- release/**
paths-ignore:
- 'app/**'
- 'input/**'

View File

@ -2,9 +2,9 @@ name: Execution Tests
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:

View File

@ -2,9 +2,9 @@ name: Test server launches without errors
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:

View File

@ -2,9 +2,9 @@ name: Unit Tests
on:
push:
branches: [ main, master ]
branches: [ main, master, release/** ]
pull_request:
branches: [ main, master ]
branches: [ main, master, release/** ]
jobs:
test:

View File

@ -6,6 +6,7 @@ on:
- "pyproject.toml"
branches:
- master
- release/**
jobs:
update-version:

View File

@ -58,8 +58,13 @@ class InternalRoutes:
return web.json_response({"error": "Invalid directory type"}, status=400)
directory = get_directory_by_type(directory_type)
def is_visible_file(entry: os.DirEntry) -> bool:
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
return entry.is_file() and not entry.name.startswith('.')
sorted_files = sorted(
(entry for entry in os.scandir(directory) if entry.is_file()),
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime
)
return web.json_response([entry.name for entry in sorted_files], status=200)

View File

@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
Latent2RGB = "latent2rgb"
TAESD = "taesd"
@classmethod
def from_string(cls, value: str):
for member in cls:
if member.value == value:
return member
return None
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

View File

@ -87,6 +87,7 @@ class IndexListCallbacks:
COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
EXECUTE_START = "execute_start"
EXECUTE_CLEANUP = "execute_cleanup"
RESIZE_COND_ITEM = "resize_cond_item"
def init_callbacks(self):
return {}
@ -166,6 +167,18 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item = cond_item.copy()
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
for cond_key, cond_value in new_cond_item.items():
# Allow callbacks to handle custom conditioning items
handled = False
for callback in comfy.patcher_extension.get_all_callbacks(
IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
):
result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
if result is not None:
new_cond_item[cond_key] = result
handled = True
break
if handled:
continue
if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):

View File

@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
@torch.no_grad()
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
"""
if solver_type not in {"phi_1", "phi_2"}:
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
# Step 2
if solver_type == "phi_1":
denoised_d = torch.lerp(denoised, denoised_2, fac)
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
elif solver_type == "phi_2":
b2 = ei_h_phi_2(-h_eta) / r
b1 = ei_h_phi_1(-h_eta) - b2
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
if inject_noise:
segment_factor = (r - 1) * h * eta
sde_noise = sde_noise * segment_factor.exp()

View File

@ -41,6 +41,11 @@ class ZImage_Control(torch.nn.Module):
ffn_dim_multiplier: float = (8.0 / 3.0),
norm_eps: float = 1e-5,
qk_norm: bool = True,
n_control_layers=6,
control_in_dim=16,
additional_in_dim=0,
broken=False,
refiner_control=False,
dtype=None,
device=None,
operations=None,
@ -49,10 +54,11 @@ class ZImage_Control(torch.nn.Module):
super().__init__()
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.additional_in_dim = 0
self.control_in_dim = 16
self.broken = broken
self.additional_in_dim = additional_in_dim
self.control_in_dim = control_in_dim
n_refiner_layers = 2
self.n_control_layers = 6
self.n_control_layers = n_control_layers
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(
@ -74,10 +80,31 @@ class ZImage_Control(torch.nn.Module):
all_x_embedder = {}
patch_size = 2
f_patch_size = 1
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.refiner_control = refiner_control
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
if self.refiner_control:
self.control_noise_refiner = nn.ModuleList(
[
ZImageControlTransformerBlock(
layer_id,
dim,
n_heads,
n_kv_heads,
multiple_of,
ffn_dim_multiplier,
norm_eps,
qk_norm,
block_id=layer_id,
operation_settings=operation_settings,
)
for layer_id in range(n_refiner_layers)
]
)
else:
self.control_noise_refiner = nn.ModuleList(
[
JointTransformerBlock(
@ -105,9 +132,29 @@ class ZImage_Control(torch.nn.Module):
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
x_attn_mask = None
if not self.refiner_control:
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
return control_context
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
if self.refiner_control:
if self.broken:
if layer_id == 0:
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if layer_id > 0:
out = None
for i in range(1, len(self.control_layers)):
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
if out is None:
out = o
return (out, control_context)
else:
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
else:
return (None, control_context)
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

View File

@ -536,6 +536,7 @@ class NextDiT(nn.Module):
bsz = len(x)
pH = pW = self.patch_size
device = x[0].device
orig_x = x
if self.pad_tokens_multiple is not None:
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@ -572,13 +573,21 @@ class NextDiT(nn.Module):
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
patches = transformer_options.get("patches", {})
# refine context
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
padded_img_mask = None
for layer in self.noise_refiner:
x_input = x
for i, layer in enumerate(self.noise_refiner):
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if "noise_refiner" in patches:
for p in patches["noise_refiner"]:
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
if "img" in out:
x = out["img"]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
@ -622,14 +631,18 @@ class NextDiT(nn.Module):
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:

View File

@ -218,8 +218,23 @@ class QwenImageTransformerBlock(nn.Module):
operations=operations,
)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
if timestep_zero_index is not None:
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
else:
return torch.addcmul(y, gate, x)
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
if timestep_zero_index is not None:
actual_batch = shift.size(0) // 2
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
else:
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
def forward(
@ -229,14 +244,19 @@ class QwenImageTransformerBlock(nn.Module):
encoder_hidden_states_mask: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
timestep_zero_index=None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
img_mod_params = self.img_mod(temb)
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
txt_mod_params = self.txt_mod(temb)
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
del img_mod1
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
del txt_mod1
@ -251,15 +271,15 @@ class QwenImageTransformerBlock(nn.Module):
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@ -302,6 +322,7 @@ class QwenImageTransformer2DModel(nn.Module):
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None,
final_layer=True,
dtype=None,
@ -314,6 +335,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.in_channels = in_channels
self.out_channels = out_channels or in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.default_ref_method = default_ref_method
self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
@ -391,11 +413,14 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states, img_ids, orig_shape = self.process_img(x)
num_embeds = hidden_states.shape[1]
timestep_zero_index = None
if ref_latents is not None:
h = 0
w = 0
index = 0
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents:
if index_ref_method:
index += 1
@ -415,6 +440,10 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@ -446,7 +475,7 @@ class QwenImageTransformer2DModel(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
hidden_states = out["img"]
@ -458,6 +487,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
timestep_zero_index=timestep_zero_index,
transformer_options=transformer_options,
)
@ -474,6 +504,9 @@ class QwenImageTransformer2DModel(nn.Module):
if add is not None:
hidden_states[:, :add.shape[1]] += add
if timestep_zero_index is not None:
temb = temb.chunk(2, dim=0)[0]
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@ -568,7 +568,10 @@ class WanModel(torch.nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -763,7 +766,10 @@ class VaceWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -862,7 +868,10 @@ class CameraWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
@ -1326,16 +1335,19 @@ class WanModel_S2V(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
if audio_emb is not None:
x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
# head
@ -1574,7 +1586,10 @@ class HumoWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -523,7 +523,10 @@ class AnimateWanModel(WanModel):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}

View File

@ -261,6 +261,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_embedder_dtype"] = torch.float32
if "__x0__" in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

View File

@ -454,6 +454,9 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input")
def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x

View File

@ -497,15 +497,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
) -> None:
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
if dtype is None:
dtype = MixedPrecisionOps._compute_dtype
self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
self._has_bias = bias
self.tensor_class = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@ -530,7 +529,14 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
dtype = self.factory_kwargs["dtype"]
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
if dtype != MixedPrecisionOps._compute_dtype:
self.comfy_cast_weights = True
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
else:
self.register_parameter("bias", None)
else:
self.quant_format = layer_conf.get("format", None)
if not self._full_precision_mm:
@ -560,6 +566,11 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
requires_grad=False
)
if self._has_bias:
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
else:
self.register_parameter("bias", None)
for param_name in qconfig["parameters"]:
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
@ -581,7 +592,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
quant_conf = {"format": self.quant_format}
if self._full_precision_mm:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
return sd
def _forward(self, input, weight, bias):

View File

@ -549,8 +549,10 @@ class VAE:
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
# Hunyuan 3d v2 2.0 & 2.1
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:

View File

@ -28,6 +28,7 @@ from . import supported_models_base
from . import latent_formats
from . import diffusers_convert
import comfy.model_management
class SD15(supported_models_base.BASE):
unet_config = {
@ -541,7 +542,7 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
memory_usage_factor = 1.2
memory_usage_factor = 1.6
text_encoder_key_prefix = ["text_encoders."]
@ -965,7 +966,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
def get_model(self, state_dict, prefix="", device=None):
out = model_base.CosmosPredict2(self, device=device)
@ -1026,9 +1027,15 @@ class ZImage(Lumina2):
"shift": 3.0,
}
memory_usage_factor = 1.7
memory_usage_factor = 2.0
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
supported_inference_dtypes = [torch.bfloat16, torch.float32]
def __init__(self, unet_config):
super().__init__(unet_config)
if comfy.model_management.extended_fp16_support():
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
self.supported_inference_dtypes.insert(1, torch.float16)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
@ -1289,7 +1296,7 @@ class ChromaRadiance(Chroma):
latent_format = comfy.latent_formats.ChromaRadiance
# Pixel-space model, no spatial compression for model input.
memory_usage_factor = 0.038
memory_usage_factor = 0.044
def get_model(self, state_dict, prefix="", device=None):
return model_base.ChromaRadiance(self, device=device)
@ -1332,7 +1339,7 @@ class Omnigen2(supported_models_base.BASE):
"shift": 2.6,
}
memory_usage_factor = 1.65 #TODO
memory_usage_factor = 1.95 #TODO
unet_extra_config = {}
latent_format = latent_formats.Flux
@ -1397,7 +1404,7 @@ class HunyuanImage21(HunyuanVideo):
latent_format = latent_formats.HunyuanImage21
memory_usage_factor = 7.7
memory_usage_factor = 8.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1488,7 +1495,7 @@ class Kandinsky5(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
memory_usage_factor = 1.1 #TODO
memory_usage_factor = 1.25 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@ -1517,7 +1524,7 @@ class Kandinsky5Image(Kandinsky5):
}
latent_format = latent_formats.Flux
memory_usage_factor = 1.1 #TODO
memory_usage_factor = 1.25 #TODO
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Kandinsky5Image(self, device=device)

View File

@ -53,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
else:
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
@ -1262,6 +1262,6 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
if quant_metadata is not None:
layers = quant_metadata["layers"]
for k, v in layers.items():
state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8)
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
return state_dict, metadata

View File

@ -5,12 +5,12 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility.
"""
from typing import Any, Dict
from typing import Any
from comfy.cli_args import args
# Default server capabilities
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
@ -18,7 +18,7 @@ SERVER_FEATURE_FLAGS: Dict[str, Any] = {
def get_connection_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sockets_metadata: dict[str, dict[str, Any]],
sid: str,
feature_name: str,
default: Any = False
@ -42,7 +42,7 @@ def get_connection_feature(
def supports_feature(
sockets_metadata: Dict[str, Dict[str, Any]],
sockets_metadata: dict[str, dict[str, Any]],
sid: str,
feature_name: str
) -> bool:
@ -60,7 +60,7 @@ def supports_feature(
return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
def get_server_features() -> Dict[str, Any]:
def get_server_features() -> dict[str, Any]:
"""
Get the server's feature flags.

View File

@ -1,4 +1,4 @@
from typing import Type, List, NamedTuple
from typing import NamedTuple
from comfy_api.internal.singleton import ProxiedSingleton
from packaging import version as packaging_version
@ -10,7 +10,7 @@ class ComfyAPIBase(ProxiedSingleton):
class ComfyAPIWithVersion(NamedTuple):
version: str
api_class: Type[ComfyAPIBase]
api_class: type[ComfyAPIBase]
def parse_version(version_str: str) -> packaging_version.Version:
@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
return packaging_version.parse(version_str)
registered_versions: List[ComfyAPIWithVersion] = []
registered_versions: list[ComfyAPIWithVersion] = []
def register_versions(versions: List[ComfyAPIWithVersion]):
def register_versions(versions: list[ComfyAPIWithVersion]):
versions.sort(key=lambda x: parse_version(x.version))
global registered_versions
registered_versions = versions
def get_all_versions() -> List[ComfyAPIWithVersion]:
def get_all_versions() -> list[ComfyAPIWithVersion]:
"""
Returns a list of all registered ComfyAPI versions.
"""

View File

@ -8,7 +8,7 @@ import os
import textwrap
import threading
from enum import Enum
from typing import Optional, Type, get_origin, get_args, get_type_hints
from typing import Optional, get_origin, get_args, get_type_hints
class TypeTracker:
@ -193,7 +193,7 @@ class AsyncToSyncConverter:
return result_container["result"]
@classmethod
def create_sync_class(cls, async_class: Type, thread_pool_size=10) -> Type:
def create_sync_class(cls, async_class: type, thread_pool_size=10) -> type:
"""
Creates a new class with synchronous versions of all async methods.
@ -563,7 +563,7 @@ class AsyncToSyncConverter:
@classmethod
def _generate_imports(
cls, async_class: Type, type_tracker: TypeTracker
cls, async_class: type, type_tracker: TypeTracker
) -> list[str]:
"""Generate import statements for the stub file."""
imports = []
@ -628,7 +628,7 @@ class AsyncToSyncConverter:
return imports
@classmethod
def _get_class_attributes(cls, async_class: Type) -> list[tuple[str, Type]]:
def _get_class_attributes(cls, async_class: type) -> list[tuple[str, type]]:
"""Extract class attributes that are classes themselves."""
class_attributes = []
@ -654,7 +654,7 @@ class AsyncToSyncConverter:
def _generate_inner_class_stub(
cls,
name: str,
attr: Type,
attr: type,
indent: str = " ",
type_tracker: Optional[TypeTracker] = None,
) -> list[str]:
@ -782,7 +782,7 @@ class AsyncToSyncConverter:
return processed
@classmethod
def generate_stub_file(cls, async_class: Type, sync_class: Type) -> None:
def generate_stub_file(cls, async_class: type, sync_class: type) -> None:
"""
Generate a .pyi stub file for the sync class to help IDEs with type checking.
"""
@ -988,7 +988,7 @@ class AsyncToSyncConverter:
logging.error(traceback.format_exc())
def create_sync_class(async_class: Type, thread_pool_size=10) -> Type:
def create_sync_class(async_class: type, thread_pool_size=10) -> type:
"""
Creates a sync version of an async class

View File

@ -1,4 +1,4 @@
from typing import Type, TypeVar
from typing import TypeVar
class SingletonMetaclass(type):
T = TypeVar("T", bound="SingletonMetaclass")
@ -11,13 +11,13 @@ class SingletonMetaclass(type):
)
return cls._instances[cls]
def inject_instance(cls: Type[T], instance: T) -> None:
def inject_instance(cls: type[T], instance: T) -> None:
assert cls not in SingletonMetaclass._instances, (
"Cannot inject instance after first instantiation"
)
SingletonMetaclass._instances[cls] = instance
def get_instance(cls: Type[T], *args, **kwargs) -> T:
def get_instance(cls: type[T], *args, **kwargs) -> T:
"""
Gets the singleton instance of the class, creating it if it doesn't exist.
"""

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Type, TYPE_CHECKING
from typing import TYPE_CHECKING
from comfy_api.internal import ComfyAPIBase
from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
@ -113,7 +113,7 @@ ComfyAPI = ComfyAPI_latest
if TYPE_CHECKING:
import comfy_api.latest.generated.ComfyAPISyncStub # type: ignore
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync: type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
ComfyAPISync = create_sync_class(ComfyAPI_latest)
# create new aliases for io and ui

View File

@ -1,5 +1,5 @@
import torch
from typing import TypedDict, List, Optional
from typing import TypedDict, Optional
ImageInput = torch.Tensor
"""
@ -39,4 +39,4 @@ class LatentInput(TypedDict):
Optional noise mask tensor in the same format as samples.
"""
batch_index: Optional[List[int]]
batch_index: Optional[list[int]]

View File

@ -774,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
class AudioEncoderOutput(ComfyTypeIO):
Type = Any
@comfytype(io_type="TRACKS")
class Tracks(ComfyTypeIO):
class TrackDict(TypedDict):
track_path: torch.Tensor
track_visibility: torch.Tensor
Type = TrackDict
@comfytype(io_type="COMFY_MULTITYPED_V3")
class MultiType:
Type = Any
@ -1815,7 +1822,7 @@ class NodeOutput(_NodeOutputInternal):
ui = data["ui"]
if "expand" in data:
expand = data["expand"]
return cls(args=args, ui=ui, expand=expand)
return cls(*args, ui=ui, expand=expand)
def __getitem__(self, index) -> Any:
return self.args[index]
@ -1894,6 +1901,7 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
"Tracks",
# Dynamic Types
"MatchType",
# "DynamicCombo",

View File

@ -5,7 +5,6 @@ import os
import random
import uuid
from io import BytesIO
from typing import Type
import av
import numpy as np
@ -83,7 +82,7 @@ class ImageSaveHelper:
return PILImage.fromarray(np.clip(255.0 * image_tensor.cpu().numpy(), 0, 255).astype(np.uint8))
@staticmethod
def _create_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
def _create_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -96,7 +95,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_animated_png_metadata(cls: Type[ComfyNode] | None) -> PngInfo | None:
def _create_animated_png_metadata(cls: type[ComfyNode] | None) -> PngInfo | None:
"""Creates a PngInfo object with prompt and extra_pnginfo for animated PNGs (APNG)."""
if args.disable_metadata or cls is None or not cls.hidden:
return None
@ -121,7 +120,7 @@ class ImageSaveHelper:
return metadata
@staticmethod
def _create_webp_metadata(pil_image: PILImage.Image, cls: Type[ComfyNode] | None) -> PILImage.Exif:
def _create_webp_metadata(pil_image: PILImage.Image, cls: type[ComfyNode] | None) -> PILImage.Exif:
"""Creates EXIF metadata bytes for WebP images."""
exif_data = pil_image.getexif()
if args.disable_metadata or cls is None or cls.hidden is None:
@ -137,7 +136,7 @@ class ImageSaveHelper:
@staticmethod
def save_images(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, compress_level = 4,
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, compress_level = 4,
) -> list[SavedResult]:
"""Saves a batch of images as individual PNG files."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -155,7 +154,7 @@ class ImageSaveHelper:
return results
@staticmethod
def get_save_images_ui(images, filename_prefix: str, cls: Type[ComfyNode] | None, compress_level=4) -> SavedImages:
def get_save_images_ui(images, filename_prefix: str, cls: type[ComfyNode] | None, compress_level=4) -> SavedImages:
"""Saves a batch of images and returns a UI object for the node output."""
return SavedImages(
ImageSaveHelper.save_images(
@ -169,7 +168,7 @@ class ImageSaveHelper:
@staticmethod
def save_animated_png(
images, filename_prefix: str, folder_type: FolderType, cls: Type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, folder_type: FolderType, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedResult:
"""Saves a batch of images as a single animated PNG."""
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
@ -191,7 +190,7 @@ class ImageSaveHelper:
@staticmethod
def get_save_animated_png_ui(
images, filename_prefix: str, cls: Type[ComfyNode] | None, fps: float, compress_level: int
images, filename_prefix: str, cls: type[ComfyNode] | None, fps: float, compress_level: int
) -> SavedImages:
"""Saves an animated PNG and returns a UI object for the node output."""
result = ImageSaveHelper.save_animated_png(
@ -209,7 +208,7 @@ class ImageSaveHelper:
images,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -238,7 +237,7 @@ class ImageSaveHelper:
def get_save_animated_webp_ui(
images,
filename_prefix: str,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
fps: float,
lossless: bool,
quality: int,
@ -267,7 +266,7 @@ class AudioSaveHelper:
audio: dict,
filename_prefix: str,
folder_type: FolderType,
cls: Type[ComfyNode] | None,
cls: type[ComfyNode] | None,
format: str = "flac",
quality: str = "128k",
) -> list[SavedResult]:
@ -372,7 +371,7 @@ class AudioSaveHelper:
@staticmethod
def get_save_audio_ui(
audio, filename_prefix: str, cls: Type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
audio, filename_prefix: str, cls: type[ComfyNode] | None, format: str = "flac", quality: str = "128k",
) -> SavedAudios:
"""Save and instantly wrap for UI."""
return SavedAudios(
@ -388,7 +387,7 @@ class AudioSaveHelper:
class PreviewImage(_UIOutput):
def __init__(self, image: Image.Type, animated: bool = False, cls: Type[ComfyNode] = None, **kwargs):
def __init__(self, image: Image.Type, animated: bool = False, cls: type[ComfyNode] = None, **kwargs):
self.values = ImageSaveHelper.save_images(
image,
filename_prefix="ComfyUI_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for _ in range(5)),
@ -412,7 +411,7 @@ class PreviewMask(PreviewImage):
class PreviewAudio(_UIOutput):
def __init__(self, audio: dict, cls: Type[ComfyNode] = None, **kwargs):
def __init__(self, audio: dict, cls: type[ComfyNode] = None, **kwargs):
self.values = AudioSaveHelper.save_audio(
audio,
filename_prefix="ComfyUI_temp_" + "".join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)),

View File

@ -2,9 +2,8 @@ from comfy_api.latest import ComfyAPI_latest
from comfy_api.v0_0_2 import ComfyAPIAdapter_v0_0_2
from comfy_api.v0_0_1 import ComfyAPIAdapter_v0_0_1
from comfy_api.internal import ComfyAPIBase
from typing import List, Type
supported_versions: List[Type[ComfyAPIBase]] = [
supported_versions: list[type[ComfyAPIBase]] = [
ComfyAPI_latest,
ComfyAPIAdapter_v0_0_2,
ComfyAPIAdapter_v0_0_1,

View File

@ -51,25 +51,25 @@ class TaskStatusImageResult(BaseModel):
url: str = Field(..., description="URL for generated image")
class OmniTaskStatusResults(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
class OmniTaskStatusResponseData(BaseModel):
class TaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: OmniTaskStatusResults | None = Field(None)
task_result: TaskStatusResults | None = Field(None)
class OmniTaskStatusResponse(BaseModel):
class TaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: OmniTaskStatusResponseData | None = Field(None)
data: TaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel):
@ -84,3 +84,21 @@ class OmniProImageRequest(BaseModel):
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
class TextToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
class ImageToVideoWithAudioRequest(BaseModel):
model_name: str = Field(..., description="kling-v2-6")
image: str = Field(...)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")

View File

@ -1,100 +0,0 @@
from typing import Optional
from enum import Enum
from pydantic import BaseModel, Field
class Pikaffect(str, Enum):
Cake_ify = "Cake-ify"
Crumble = "Crumble"
Crush = "Crush"
Decapitate = "Decapitate"
Deflate = "Deflate"
Dissolve = "Dissolve"
Explode = "Explode"
Eye_pop = "Eye-pop"
Inflate = "Inflate"
Levitate = "Levitate"
Melt = "Melt"
Peel = "Peel"
Poke = "Poke"
Squish = "Squish"
Ta_da = "Ta-da"
Tear = "Tear"
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
duration: Optional[int] = Field(5)
ingredientsMode: str = Field(...)
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = Field('1080p')
seed: Optional[int] = Field(None)
class PikaGenerateResponse(BaseModel):
video_id: str = Field(...)
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
duration: Optional[int] = Field(None, ge=5, le=10)
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
aspectRatio: Optional[float] = Field(
1.7777777777777777,
description='Aspect ratio (width / height)',
ge=0.4,
le=2.5,
)
duration: Optional[int] = 5
negativePrompt: Optional[str] = Field(None)
promptText: str = Field(...)
resolution: Optional[str] = '1080p'
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
pikaffect: Optional[str] = None
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
negativePrompt: Optional[str] = Field(None)
promptText: Optional[str] = Field(None)
seed: Optional[int] = Field(None)
modifyRegionRoi: Optional[str] = Field(None)
class PikaStatusEnum(str, Enum):
queued = "queued"
started = "started"
finished = "finished"
failed = "failed"
class PikaVideoResponse(BaseModel):
id: str = Field(...)
progress: Optional[int] = Field(None)
status: PikaStatusEnum
url: Optional[str] = Field(None)

View File

@ -5,11 +5,17 @@ from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
@ -61,14 +67,20 @@ class TripoSpec(str, Enum):
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
RUN = "preset:run"
DIVE = "preset:dive"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
QUADRUPED_WALK = "preset:quadruped:walk"
HEXAPOD_WALK = "preset:hexapod:walk"
OCTOPOD_WALK = "preset:octopod:walk"
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
@ -105,6 +117,11 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned"
EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
@ -142,6 +159,7 @@ class TripoTextToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
@ -156,6 +174,7 @@ class TripoImageToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
@ -173,6 +192,7 @@ class TripoMultiviewToModelRequest(BaseModel):
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
@ -219,14 +239,24 @@ class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(None, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
class TripoTaskRequest(RootModel):
root: Union[

View File

@ -50,6 +50,7 @@ from comfy_api_nodes.apis import (
KlingSingleImageEffectModelName,
)
from comfy_api_nodes.apis.kling_api import (
ImageToVideoWithAudioRequest,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@ -57,7 +58,8 @@ from comfy_api_nodes.apis.kling_api import (
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
OmniTaskStatusResponse,
TaskStatusResponse,
TextToVideoWithAudioRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -103,10 +105,6 @@ AVERAGE_DURATION_VIDEO_EXTEND = 320
MODE_TEXT2VIDEO = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
"standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
"standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
"pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
@ -127,8 +125,6 @@ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document
MODE_START_END_FRAME = {
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
"pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
"pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
"pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
@ -242,7 +238,7 @@ def normalize_omni_prompt_references(prompt: str) -> str:
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
@ -250,7 +246,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStat
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
@ -752,7 +748,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[4],
default=modes[8],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -834,7 +830,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -929,7 +925,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
@ -997,7 +993,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1081,7 +1077,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1162,7 +1158,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
@ -1237,7 +1233,7 @@ class OmniProImageNode(IO.ComfyNode):
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
@ -1253,7 +1249,7 @@ class OmniProImageNode(IO.ComfyNode):
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
@ -1328,9 +1324,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image to Video",
display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
description="Kling Image to Video Node",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1488,7 +1483,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
IO.Combo.Input(
"mode",
options=modes,
default=modes[8],
default=modes[6],
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
),
],
@ -1951,7 +1946,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
IO.Combo.Input(
"model_name",
options=[i.value for i in KlingImageGenModelName],
default="kling-v1",
default="kling-v2",
),
IO.Combo.Input(
"aspect_ratio",
@ -2034,6 +2029,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.NodeOutput(await image_result_to_node_output(images))
class TextToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
mode: str,
aspect_ratio: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
response_model=TaskStatusResponse,
data=TextToVideoWithAudioRequest(
model_name=model_name,
prompt=prompt,
mode=mode,
aspect_ratio=aspect_ratio,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class ImageToVideoWithAudio(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
IO.Combo.Input("mode", options=["pro"]),
IO.Combo.Input("duration", options=[5, 10]),
IO.Boolean.Input("generate_audio", default=True),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
start_frame: Input.Image,
prompt: str,
mode: str,
duration: int,
generate_audio: bool,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
response_model=TaskStatusResponse,
data=ImageToVideoWithAudioRequest(
model_name=model_name,
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
prompt=prompt,
mode=mode,
duration=str(duration),
sound="on" if generate_audio else "off",
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -2056,7 +2181,9 @@ class KlingExtension(ComfyExtension):
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
# OmniProImageNode, # need support from backend
OmniProImageNode,
TextToVideoWithAudio,
ImageToVideoWithAudio,
]

View File

@ -1,575 +0,0 @@
"""
Pika x ComfyUI API Nodes
Pika API docs: https://pika-827374fb.mintlify.app/api-reference
"""
from __future__ import annotations
from io import BytesIO
import logging
from typing import Optional
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint,
sync_op,
poll_op,
)
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
PIKA_API_VERSION = "2.2"
PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task(
task_id: str,
cls: type[IO.ComfyNode],
) -> IO.NodeOutput:
final_response: pika_defs.PikaVideoResponse = await poll_op(
cls,
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
response_model=pika_defs.PikaVideoResponse,
status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
estimated_duration=60,
max_poll_attempts=240,
)
if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg)
raise Exception(error_msg)
video_url = final_response.url
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
return IO.NodeOutput(await download_url_to_video_output(video_url))
def get_base_inputs_types() -> list[IO.Input]:
"""Get the base required inputs types common to all Pika nodes."""
return [
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
IO.Combo.Input("duration", options=[5, 10], default=5),
]
class PikaImageToVideo(IO.ComfyNode):
"""Pika 2.2 Image to Video Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaImageToVideoNode2_2",
display_name="Pika Image to Video",
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The image to convert to video"),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
image_bytes_io = tensor_to_bytesio(image)
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode):
"""Pika Text2Video v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaTextToVideoNode2_2",
display_name="Pika Text to Video",
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
)
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
aspect_ratio: float,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
),
content_type="application/x-www-form-urlencoded",
)
return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode):
"""PikaScenes v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaScenesV2_2",
display_name="Pika Scenes (Video Image Composition)",
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
category="api node/video/Pika",
inputs=[
*get_base_inputs_types(),
IO.Combo.Input(
"ingredients_mode",
options=["creative", "precise"],
default="creative",
),
IO.Float.Input(
"aspect_ratio",
step=0.001,
min=0.4,
max=2.5,
default=1.7777777777777777,
tooltip="Aspect ratio (width / height)",
),
IO.Image.Input(
"image_ingredient_1",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_2",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_3",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_4",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
IO.Image.Input(
"image_ingredient_5",
optional=True,
tooltip="Image that will be used as ingredient to create a video.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
ingredients_mode: str,
aspect_ratio: float,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
image_ingredient_4: Optional[torch.Tensor] = None,
image_ingredient_5: Optional[torch.Tensor] = None,
) -> IO.NodeOutput:
all_image_bytes_io = []
for image in [
image_ingredient_1,
image_ingredient_2,
image_ingredient_3,
image_ingredient_4,
image_ingredient_5,
]:
if image is not None:
all_image_bytes_io.append(tensor_to_bytesio(image))
pika_files = [
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
for i, image_bytes_io in enumerate(all_image_bytes_io)
]
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
ingredientsMode=ingredients_mode,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
aspectRatio=aspect_ratio,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode):
"""Pika Pikadditions Node. Add an image into a video."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikadditions",
display_name="Pikadditions (Video Object Insertion)",
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to add an image to."),
IO.Image.Input("image", tooltip="The image to add to the video."),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input(
"seed",
min=0,
max=0xFFFFFFFF,
control_after_generate=True,
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
image_bytes_io = tensor_to_bytesio(image)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
"image": ("image.png", image_bytes_io, "image/png"),
}
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode):
"""Pika Pikaswaps Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaswaps",
display_name="Pika Swaps (Video Object Replacement)",
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
category="api node/video/Pika",
inputs=[
IO.Video.Input("video", tooltip="The video to swap an object in."),
IO.Image.Input(
"image",
tooltip="The image used to replace the masked object in the video.",
optional=True,
),
IO.Mask.Input(
"mask",
tooltip="Use the mask to define areas in the video to replace.",
optional=True,
),
IO.String.Input("prompt_text", multiline=True, optional=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
IO.String.Input(
"region_to_modify",
multiline=True,
optional=True,
tooltip="Plaintext description of the object / region to modify.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
video: VideoInput,
image: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
prompt_text: str = "",
negative_prompt: str = "",
seed: int = 0,
region_to_modify: str = "",
) -> IO.NodeOutput:
video_bytes_io = BytesIO()
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
video_bytes_io.seek(0)
pika_files = {
"video": ("video.mp4", video_bytes_io, "video/mp4"),
}
if mask is not None:
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
if image is not None:
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None,
)
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_request_data,
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode):
"""Pika Pikaffects Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Pikaffects",
display_name="Pikaffects (Video Effects)",
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
IO.Combo.Input(
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
),
IO.String.Input("prompt_text", multiline=True),
IO.String.Input("negative_prompt", multiline=True),
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
pikaffect: str,
prompt_text: str,
negative_prompt: str,
seed: int,
) -> IO.NodeOutput:
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect,
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode):
"""PikaFrames v2.2 Node."""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="PikaStartEndFrameNode2_2",
display_name="Pika Start and End Frame to Video",
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
category="api node/video/Pika",
inputs=[
IO.Image.Input("image_start", tooltip="The first image to combine."),
IO.Image.Input("image_end", tooltip="The last image to combine."),
*get_base_inputs_types(),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
async def execute(
cls,
image_start: torch.Tensor,
image_end: torch.Tensor,
prompt_text: str,
negative_prompt: str,
seed: int,
resolution: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt_text, field_name="prompt_text", min_length=1)
pika_files = [
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
]
initial_operation = await sync_op(
cls,
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
response_model=pika_defs.PikaGenerateResponse,
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text,
negativePrompt=negative_prompt,
seed=seed,
resolution=resolution,
duration=duration,
),
files=pika_files,
content_type="multipart/form-data",
)
return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
PikaImageToVideo,
PikaTextToVideoNode,
PikaScenes,
PikAdditionsNode,
PikaSwapsNode,
PikaffectsNode,
PikaStartEndFrameNode,
]
async def comfy_entrypoint() -> PikaApiNodesExtension:
return PikaApiNodesExtension()

View File

@ -102,8 +102,9 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("model_seed", default=42, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -131,6 +132,7 @@ class TripoTextToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
) -> IO.NodeOutput:
@ -154,6 +156,7 @@ class TripoTextToModelNode(IO.ComfyNode):
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
geometry_quality=geometry_quality,
auto_size=True,
quad=quad,
),
@ -194,6 +197,7 @@ class TripoImageToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -220,6 +224,7 @@ class TripoImageToModelNode(IO.ComfyNode):
orientation=None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -246,6 +251,7 @@ class TripoImageToModelNode(IO.ComfyNode):
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
@ -295,6 +301,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True),
IO.Boolean.Input("quad", default=False, optional=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True),
],
outputs=[
IO.String.Output(display_name="model_file"),
@ -323,6 +330,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed: Optional[int] = None,
texture_seed: Optional[int] = None,
texture_quality: Optional[str] = None,
geometry_quality: Optional[str] = None,
texture_alignment: Optional[str] = None,
face_limit: Optional[int] = None,
quad: Optional[bool] = None,
@ -359,6 +367,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
@ -508,6 +517,8 @@ class TripoRetargetNode(IO.ComfyNode):
options=[
"preset:idle",
"preset:walk",
"preset:run",
"preset:dive",
"preset:climb",
"preset:jump",
"preset:slash",
@ -515,6 +526,11 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hurt",
"preset:fall",
"preset:turn",
"preset:quadruped:walk",
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
],
),
],
@ -563,7 +579,7 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit",
default=-1,
min=-1,
max=500000,
max=2000000,
optional=True,
),
IO.Int.Input(
@ -579,6 +595,40 @@ class TripoConversionNode(IO.ComfyNode):
default="JPEG",
optional=True,
),
IO.Boolean.Input("force_symmetry", default=False, optional=True),
IO.Boolean.Input("flatten_bottom", default=False, optional=True),
IO.Float.Input(
"flatten_bottom_threshold",
default=0.0,
min=0.0,
max=1.0,
optional=True,
),
IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True),
IO.Float.Input(
"scale_factor",
default=1.0,
min=0.0,
optional=True,
),
IO.Boolean.Input("with_animation", default=False, optional=True),
IO.Boolean.Input("pack_uv", default=False, optional=True),
IO.Boolean.Input("bake", default=False, optional=True),
IO.String.Input("part_names", default="", optional=True), # comma-separated list
IO.Combo.Input(
"fbx_preset",
options=["blender", "mixamo", "3dsmax"],
default="blender",
optional=True,
),
IO.Boolean.Input("export_vertex_colors", default=False, optional=True),
IO.Combo.Input(
"export_orientation",
options=["align_image", "default"],
default="default",
optional=True,
),
IO.Boolean.Input("animate_in_place", default=False, optional=True),
],
outputs=[],
hidden=[
@ -604,12 +654,31 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id,
format: str,
quad: bool,
force_symmetry: bool,
face_limit: int,
flatten_bottom: bool,
flatten_bottom_threshold: float,
texture_size: int,
texture_format: str,
pivot_to_center_bottom: bool,
scale_factor: float,
with_animation: bool,
pack_uv: bool,
bake: bool,
part_names: str,
fbx_preset: str,
export_vertex_colors: bool,
export_orientation: str,
animate_in_place: bool,
) -> IO.NodeOutput:
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
@ -618,9 +687,22 @@ class TripoConversionNode(IO.ComfyNode):
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
force_symmetry=force_symmetry if force_symmetry else None,
face_limit=face_limit if face_limit != -1 else None,
flatten_bottom=flatten_bottom if flatten_bottom else None,
flatten_bottom_threshold=flatten_bottom_threshold if flatten_bottom_threshold != 0.0 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None,
pivot_to_center_bottom=pivot_to_center_bottom if pivot_to_center_bottom else None,
scale_factor=scale_factor if scale_factor != 1.0 else None,
with_animation=with_animation if with_animation else None,
pack_uv=pack_uv if pack_uv else None,
bake=bake if bake else None,
part_names=part_names_list,
fbx_preset=fbx_preset if fbx_preset != "blender" else None,
export_vertex_colors=export_vertex_colors if export_vertex_colors else None,
export_orientation=export_orientation if export_orientation != "default" else None,
animate_in_place=animate_in_place if animate_in_place else None,
),
)
return await poll_until_finished(cls, response, average_duration=30)

View File

@ -659,6 +659,31 @@ class SamplerSASolver(io.ComfyNode):
get_sampler = execute
class SamplerSEEDS2(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerSEEDS2",
category="sampling/custom_sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"),
io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"),
],
outputs=[io.Sampler.Output()]
)
@classmethod
def execute(cls, solver_type, eta, s_noise, r) -> io.NodeOutput:
sampler_name = "seeds_2"
sampler = comfy.samplers.ksampler(
sampler_name,
{"eta": eta, "s_noise": s_noise, "r": r, "solver_type": solver_type},
)
return io.NodeOutput(sampler)
class Noise_EmptyNoise:
def __init__(self):
self.seed = 0
@ -996,6 +1021,7 @@ class CustomSamplersExtension(ComfyExtension):
SamplerDPMAdaptative,
SamplerER_SDE,
SamplerSASolver,
SamplerSEEDS2,
SplitSigmas,
SplitSigmasDenoise,
FlipSigmas,

View File

@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FluxKontextMultiReferenceLatentMethod",
display_name="Edit Model Reference Method",
category="advanced/conditioning/flux",
inputs=[
io.Conditioning.Input("conditioning"),
io.Combo.Input(
"reference_latents_method",
options=["offset", "index", "uxo/uno"],
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
),
],
outputs=[

View File

@ -243,7 +243,16 @@ class ModelPatchLoader:
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
sd = z_image_convert(sd)
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
config = {}
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
config['n_control_layers'] = 15
config['additional_in_dim'] = 17
config['refiner_control'] = True
ref_weight = sd.get("control_noise_refiner.0.after_proj.weight", None)
if ref_weight is not None:
if torch.count_nonzero(ref_weight) == 0:
config['broken'] = True
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@ -297,51 +306,110 @@ class DiffSynthCnetPatch:
return [self.model_patch]
class ZImageControlPatch:
def __init__(self, model_patch, vae, image, strength):
def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
self.model_patch = model_patch
self.vae = vae
self.image = image
self.inpaint_image = inpaint_image
self.mask = mask
self.strength = strength
self.encoded_image = self.encode_latent_cond(image)
self.encoded_image_size = (image.shape[1], image.shape[2])
self.is_inpaint = self.model_patch.model.additional_in_dim > 0
skip_encoding = False
if self.image is not None and self.inpaint_image is not None:
if self.image.shape != self.inpaint_image.shape:
skip_encoding = True
if skip_encoding:
self.encoded_image = None
else:
self.encoded_image = self.encode_latent_cond(self.image, self.inpaint_image)
if self.image is None:
self.encoded_image_size = (self.inpaint_image.shape[1], self.inpaint_image.shape[2])
else:
self.encoded_image_size = (self.image.shape[1], self.image.shape[2])
self.temp_data = None
def encode_latent_cond(self, image):
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
def encode_latent_cond(self, control_image=None, inpaint_image=None):
latent_image = None
if control_image is not None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
if self.is_inpaint:
if inpaint_image is None:
inpaint_image = torch.ones_like(control_image) * 0.5
if self.mask is not None:
mask_inpaint = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image.shape[-2], inpaint_image.shape[-3], "bilinear", "center")
inpaint_image = ((inpaint_image - 0.5) * mask_inpaint.movedim(1, -1).round()) + 0.5
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
if self.mask is None:
mask_ = torch.zeros_like(inpaint_image_latent)[:, :1]
else:
mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center")
if latent_image is None:
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5))
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
else:
return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
img_input = kwargs.get("img_input")
txt = kwargs.get("txt")
pe = kwargs.get("pe")
vec = kwargs.get("vec")
block_index = kwargs.get("block_index")
block_type = kwargs.get("block_type", "")
spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
image_scaled = None
if self.image is not None:
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (image_scaled.shape[-3], image_scaled.shape[-2])
inpaint_scaled = None
if self.inpaint_image is not None:
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
self.encoded_image_size = (inpaint_scaled.shape[-3], inpaint_scaled.shape[-2])
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
self.encoded_image = self.encode_latent_cond(image_scaled, inpaint_scaled)
comfy.model_management.load_models_gpu(loaded_models)
cnet_index = (block_index // 5)
cnet_index_float = (block_index / 5)
cnet_blocks = self.model_patch.model.n_control_layers
div = round(30 / cnet_blocks)
cnet_index = (block_index // div)
cnet_index_float = (block_index / div)
kwargs.pop("img") # we do ops in place
kwargs.pop("txt")
cnet_blocks = self.model_patch.model.n_control_layers
if cnet_index_float > (cnet_blocks - 1):
self.temp_data = None
return kwargs
if self.temp_data is None or self.temp_data[0] > cnet_index:
if block_type == "noise_refiner":
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
else:
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
if block_type == "noise_refiner":
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if self.temp_data[1][0] is not None:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
else:
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if cnet_index_float == self.temp_data[0]:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
@ -352,6 +420,7 @@ class ZImageControlPatch:
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
if self.encoded_image is not None:
self.encoded_image = self.encoded_image.to(device_or_dtype)
self.temp_data = None
return self
@ -375,9 +444,12 @@ class QwenImageDiffsynthControlnet:
CATEGORY = "advanced/loaders/qwen"
def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=None):
def diffsynth_controlnet(self, model, model_patch, vae, image=None, strength=1.0, inpaint_image=None, mask=None):
model_patched = model.clone()
if image is not None:
image = image[:, :, :, :3]
if inpaint_image is not None:
inpaint_image = inpaint_image[:, :, :, :3]
if mask is not None:
if mask.ndim == 3:
mask = mask.unsqueeze(1)
@ -386,11 +458,24 @@ class QwenImageDiffsynthControlnet:
mask = 1.0 - mask
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
patch = ZImageControlPatch(model_patch, vae, image, strength, inpaint_image=inpaint_image, mask=mask)
model_patched.set_model_noise_refiner_patch(patch)
model_patched.set_model_double_block_patch(patch)
else:
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
class ZImageFunControlnet(QwenImageDiffsynthControlnet):
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"model_patch": ("MODEL_PATCH",),
"vae": ("VAE",),
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
},
"optional": {"image": ("IMAGE",), "inpaint_image": ("IMAGE",), "mask": ("MASK",)}}
CATEGORY = "advanced/loaders/zimage"
class UsoStyleProjectorPatch:
def __init__(self, model_patch, encoded_image):
@ -438,5 +523,6 @@ class USOStyleReference:
NODE_CLASS_MAPPINGS = {
"ModelPatchLoader": ModelPatchLoader,
"QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
"ZImageFunControlnet": ZImageFunControlnet,
"USOStyleReference": USOStyleReference,
}

View File

@ -2,6 +2,8 @@ from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.torch_helpers import set_torch_compile_wrapper
def skip_torch_compile_dict(guard_entries):
return [("transformer_options" not in entry.name) for entry in guard_entries]
class TorchCompileModel(io.ComfyNode):
@classmethod
@ -23,7 +25,7 @@ class TorchCompileModel(io.ComfyNode):
@classmethod
def execute(cls, model, backend) -> io.NodeOutput:
m = model.clone()
set_torch_compile_wrapper(model=m, backend=backend)
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
return io.NodeOutput(m)

View File

@ -0,0 +1,535 @@
import nodes
import node_helpers
import torch
import torchvision.transforms.functional as TF
import comfy.model_management
import comfy.utils
import numpy as np
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_extras.nodes_wan import parse_json_tracks
# https://github.com/ali-vilab/Wan-Move/blob/main/wan/modules/trajectory.py
from PIL import Image, ImageDraw
SKIP_ZERO = False
def get_pos_emb(
pos_k: torch.Tensor, # A 1D tensor containing positions for which to generate embeddings.
pos_emb_dim: int,
theta_func: callable = lambda i, d: torch.pow(10000, torch.mul(2, torch.div(i.to(torch.float32), d))), #Function to compute thetas based on position and embedding dimensions.
device: torch.device = torch.device("cpu"),
dtype: torch.dtype = torch.float32,
) -> torch.Tensor: # The position embeddings (batch_size, pos_emb_dim)
assert pos_emb_dim % 2 == 0, "The dimension of position embeddings must be even."
pos_k = pos_k.to(device, dtype)
if SKIP_ZERO:
pos_k = pos_k + 1
batch_size = pos_k.size(0)
denominator = torch.arange(0, pos_emb_dim // 2, device=device, dtype=dtype)
# Expand denominator to match the shape needed for broadcasting
denominator_expanded = denominator.view(1, -1).expand(batch_size, -1)
thetas = theta_func(denominator_expanded, pos_emb_dim)
# Ensure pos_k is in the correct shape for broadcasting
pos_k_expanded = pos_k.view(-1, 1).to(dtype)
sin_thetas = torch.sin(torch.div(pos_k_expanded, thetas))
cos_thetas = torch.cos(torch.div(pos_k_expanded, thetas))
# Concatenate sine and cosine embeddings along the last dimension
pos_emb = torch.cat([sin_thetas, cos_thetas], dim=-1)
return pos_emb
def create_pos_embeddings(
pred_tracks: torch.Tensor, # the predicted tracks, [T, N, 2]
pred_visibility: torch.Tensor, # the predicted visibility [T, N]
downsample_ratios: list[int], # the ratios for downsampling time, height, and width
height: int, # the height of the feature map
width: int, # the width of the feature map
track_num: int = -1, # the number of tracks to use
t_down_strategy: str = "sample", # the strategy for downsampling time dimension
):
assert t_down_strategy in ["sample", "average"], "Invalid strategy for downsampling time dimension."
t, n, _ = pred_tracks.shape
t_down, h_down, w_down = downsample_ratios
track_pos = - torch.ones(n, (t-1) // t_down + 1, 2, dtype=torch.long)
if track_num == -1:
track_num = n
tracks_idx = torch.randperm(n)[:track_num]
tracks = pred_tracks[:, tracks_idx]
visibility = pred_visibility[:, tracks_idx]
for t_idx in range(0, t, t_down):
if t_down_strategy == "sample" or t_idx == 0:
cur_tracks = tracks[t_idx] # [N, 2]
cur_visibility = visibility[t_idx] # [N]
else:
cur_tracks = tracks[t_idx:t_idx+t_down].mean(dim=0)
cur_visibility = torch.any(visibility[t_idx:t_idx+t_down], dim=0)
for i in range(track_num):
if not cur_visibility[i] or cur_tracks[i][0] < 0 or cur_tracks[i][1] < 0 or cur_tracks[i][0] >= width or cur_tracks[i][1] >= height:
continue
x, y = cur_tracks[i]
x, y = int(x // w_down), int(y // h_down)
track_pos[i, t_idx // t_down, 0], track_pos[i, t_idx // t_down, 1] = y, x
return track_pos # the position embeddings, [N, T', 2], 2 = height, width
def replace_feature(
vae_feature: torch.Tensor, # [B, C', T', H', W']
track_pos: torch.Tensor, # [B, N, T', 2]
strength: float = 1.0
) -> torch.Tensor:
b, _, t, h, w = vae_feature.shape
assert b == track_pos.shape[0], "Batch size mismatch."
n = track_pos.shape[1]
# Shuffle the trajectory order
track_pos = track_pos[:, torch.randperm(n), :, :]
# Extract coordinates at time steps ≥ 1 and generate a valid mask
current_pos = track_pos[:, :, 1:, :] # [B, N, T-1, 2]
mask = (current_pos[..., 0] >= 0) & (current_pos[..., 1] >= 0) # [B, N, T-1]
# Get all valid indices
valid_indices = mask.nonzero(as_tuple=False) # [num_valid, 3]
num_valid = valid_indices.shape[0]
if num_valid == 0:
return vae_feature
# Decompose valid indices into each dimension
batch_idx = valid_indices[:, 0]
track_idx = valid_indices[:, 1]
t_rel = valid_indices[:, 2]
t_target = t_rel + 1 # Convert to original time step indices
# Extract target position coordinates
h_target = current_pos[batch_idx, track_idx, t_rel, 0].long() # Ensure integer indices
w_target = current_pos[batch_idx, track_idx, t_rel, 1].long()
# Extract source position coordinates (t=0)
h_source = track_pos[batch_idx, track_idx, 0, 0].long()
w_source = track_pos[batch_idx, track_idx, 0, 1].long()
# Get source features and assign to target positions
src_features = vae_feature[batch_idx, :, 0, h_source, w_source]
dst_features = vae_feature[batch_idx, :, t_target, h_target, w_target]
vae_feature[batch_idx, :, t_target, h_target, w_target] = dst_features + (src_features - dst_features) * strength
return vae_feature
# Visualize functions
def _draw_gradient_polyline_on_overlay(overlay, line_width, points, start_color, opacity=1.0):
draw = ImageDraw.Draw(overlay, 'RGBA')
points = points[::-1]
# Compute total length
total_length = 0
segment_lengths = []
for i in range(len(points) - 1):
dx = points[i + 1][0] - points[i][0]
dy = points[i + 1][1] - points[i][1]
length = (dx * dx + dy * dy) ** 0.5
segment_lengths.append(length)
total_length += length
if total_length == 0:
return
accumulated_length = 0
# Draw the gradient polyline
for idx, (start_point, end_point) in enumerate(zip(points[:-1], points[1:])):
segment_length = segment_lengths[idx]
steps = max(int(segment_length), 1)
for i in range(steps):
current_length = accumulated_length + (i / steps) * segment_length
ratio = current_length / total_length
alpha = int(255 * (1 - ratio) * opacity)
color = (*start_color, alpha)
x = int(start_point[0] + (end_point[0] - start_point[0]) * i / steps)
y = int(start_point[1] + (end_point[1] - start_point[1]) * i / steps)
dynamic_line_width = max(int(line_width * (1 - ratio)), 1)
draw.line([(x, y), (x + 1, y)], fill=color, width=dynamic_line_width)
accumulated_length += segment_length
def add_weighted(rgb, track):
rgb = np.array(rgb) # [H, W, C] "RGB"
track = np.array(track) # [H, W, C] "RGBA"
alpha = track[:, :, 3] / 255.0
alpha = np.stack([alpha] * 3, axis=-1)
blend_img = track[:, :, :3] * alpha + rgb * (1 - alpha)
return Image.fromarray(blend_img.astype(np.uint8))
def draw_tracks_on_video(video, tracks, visibility=None, track_frame=24, circle_size=12, opacity=0.5, line_width=16):
color_map = [(102, 153, 255), (0, 255, 255), (255, 255, 0), (255, 102, 204), (0, 255, 0)]
video = video.byte().cpu().numpy() # (81, 480, 832, 3)
tracks = tracks[0].long().detach().cpu().numpy()
if visibility is not None:
visibility = visibility[0].detach().cpu().numpy()
num_frames, height, width = video.shape[:3]
num_tracks = tracks.shape[1]
alpha_opacity = int(255 * opacity)
output_frames = []
for t in range(num_frames):
frame_rgb = video[t].astype(np.float32)
# Create a single RGBA overlay for all tracks in this frame
overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
draw_overlay = ImageDraw.Draw(overlay)
polyline_data = []
# Draw all circles on a single overlay
for n in range(num_tracks):
if visibility is not None and visibility[t, n] == 0:
continue
track_coord = tracks[t, n]
color = color_map[n % len(color_map)]
circle_color = color + (alpha_opacity,)
draw_overlay.ellipse((track_coord[0] - circle_size, track_coord[1] - circle_size, track_coord[0] + circle_size, track_coord[1] + circle_size),
fill=circle_color
)
# Store polyline data for batch processing
tracks_coord = tracks[max(t - track_frame, 0):t + 1, n]
if len(tracks_coord) > 1:
polyline_data.append((tracks_coord, color))
# Blend circles overlay once
overlay_np = np.array(overlay)
alpha = overlay_np[:, :, 3:4] / 255.0
frame_rgb = overlay_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
# Draw all polylines on a single overlay
if polyline_data:
polyline_overlay = Image.new("RGBA", (width, height), (0, 0, 0, 0))
for tracks_coord, color in polyline_data:
_draw_gradient_polyline_on_overlay(polyline_overlay, line_width, tracks_coord, color, opacity)
# Blend polylines overlay once
polyline_np = np.array(polyline_overlay)
alpha = polyline_np[:, :, 3:4] / 255.0
frame_rgb = polyline_np[:, :, :3] * alpha + frame_rgb * (1 - alpha)
output_frames.append(Image.fromarray(frame_rgb.astype(np.uint8)))
return output_frames
class WanMoveVisualizeTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveVisualizeTracks",
category="conditioning/video_models",
inputs=[
io.Image.Input("images"),
io.Tracks.Input("tracks", optional=True),
io.Int.Input("line_resolution", default=24, min=1, max=1024),
io.Int.Input("circle_size", default=12, min=1, max=128),
io.Float.Input("opacity", default=0.75, min=0.0, max=1.0, step=0.01),
io.Int.Input("line_width", default=16, min=1, max=128),
],
outputs=[
io.Image.Output(),
],
)
@classmethod
def execute(cls, images, line_resolution, circle_size, opacity, line_width, tracks=None) -> io.NodeOutput:
if tracks is None:
return io.NodeOutput(images)
track_path = tracks["track_path"].unsqueeze(0)
track_visibility = tracks["track_visibility"].unsqueeze(0)
images_in = images * 255.0
if images_in.shape[0] != track_path.shape[1]:
repeat_count = track_path.shape[1] // images.shape[0]
images_in = images_in.repeat(repeat_count, 1, 1, 1)
track_video = draw_tracks_on_video(images_in, track_path, track_visibility, track_frame=line_resolution, circle_size=circle_size, opacity=opacity, line_width=line_width)
track_video = torch.stack([TF.to_tensor(frame) for frame in track_video], dim=0).movedim(1, -1).float()
return io.NodeOutput(track_video.to(comfy.model_management.intermediate_device()))
class WanMoveTracksFromCoords(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTracksFromCoords",
category="conditioning/video_models",
inputs=[
io.String.Input("track_coords", force_input=True, default="[]", optional=True),
io.Mask.Input("track_mask", optional=True),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, track_coords, track_mask=None) -> io.NodeOutput:
device=comfy.model_management.intermediate_device()
tracks_data = parse_json_tracks(track_coords)
track_length = len(tracks_data[0])
track_list = [
[[track[frame]['x'], track[frame]['y']] for track in tracks_data]
for frame in range(len(tracks_data[0]))
]
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
num_tracks = tracks.shape[-2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class GenerateTracks(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="GenerateTracks",
category="conditioning/video_models",
inputs=[
io.Int.Input("width", default=832, min=16, max=4096, step=16),
io.Int.Input("height", default=480, min=16, max=4096, step=16),
io.Float.Input("start_x", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for start position."),
io.Float.Input("start_y", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for start position."),
io.Float.Input("end_x", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized X coordinate (0-1) for end position."),
io.Float.Input("end_y", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y coordinate (0-1) for end position."),
io.Int.Input("num_frames", default=81, min=1, max=1024),
io.Int.Input("num_tracks", default=5, min=1, max=100),
io.Float.Input("track_spread", default=0.025, min=0.0, max=1.0, step=0.001, tooltip="Normalized distance between tracks. Tracks are spread perpendicular to the motion direction."),
io.Boolean.Input("bezier", default=False, tooltip="Enable Bezier curve path using the mid point as control point."),
io.Float.Input("mid_x", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized X control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Float.Input("mid_y", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Normalized Y control point for Bezier curve. Only used when 'bezier' is enabled."),
io.Combo.Input(
"interpolation",
options=["linear", "ease_in", "ease_out", "ease_in_out", "constant"],
tooltip="Controls the timing/speed of movement along the path.",
),
io.Mask.Input("track_mask", optional=True, tooltip="Optional mask to indicate visible frames."),
],
outputs=[
io.Tracks.Output(),
io.Int.Output(display_name="track_length"),
],
)
@classmethod
def execute(cls, width, height, start_x, start_y, mid_x, mid_y, end_x, end_y, num_frames, num_tracks,
track_spread, bezier=False, interpolation="linear", track_mask=None) -> io.NodeOutput:
device = comfy.model_management.intermediate_device()
track_length = num_frames
# normalized coordinates to pixel coordinates
start_x_px = start_x * width
start_y_px = start_y * height
mid_x_px = mid_x * width
mid_y_px = mid_y * height
end_x_px = end_x * width
end_y_px = end_y * height
track_spread_px = track_spread * (width + height) / 2 # Use average of width/height for spread to keep it proportional
t = torch.linspace(0, 1, num_frames, device=device)
if interpolation == "constant": # All points stay at start position
interp_values = torch.zeros_like(t)
elif interpolation == "linear":
interp_values = t
elif interpolation == "ease_in":
interp_values = t ** 2
elif interpolation == "ease_out":
interp_values = 1 - (1 - t) ** 2
elif interpolation == "ease_in_out":
interp_values = t * t * (3 - 2 * t)
if bezier: # apply interpolation to t for timing control along the bezier path
t_interp = interp_values
one_minus_t = 1 - t_interp
x_positions = one_minus_t ** 2 * start_x_px + 2 * one_minus_t * t_interp * mid_x_px + t_interp ** 2 * end_x_px
y_positions = one_minus_t ** 2 * start_y_px + 2 * one_minus_t * t_interp * mid_y_px + t_interp ** 2 * end_y_px
tangent_x = 2 * one_minus_t * (mid_x_px - start_x_px) + 2 * t_interp * (end_x_px - mid_x_px)
tangent_y = 2 * one_minus_t * (mid_y_px - start_y_px) + 2 * t_interp * (end_y_px - mid_y_px)
else: # calculate base x and y positions for each frame (center track)
x_positions = start_x_px + (end_x_px - start_x_px) * interp_values
y_positions = start_y_px + (end_y_px - start_y_px) * interp_values
# For non-bezier, tangent is constant (direction from start to end)
tangent_x = torch.full_like(t, end_x_px - start_x_px)
tangent_y = torch.full_like(t, end_y_px - start_y_px)
track_list = []
for frame_idx in range(num_frames):
# Calculate perpendicular direction at this frame
tx = tangent_x[frame_idx].item()
ty = tangent_y[frame_idx].item()
length = (tx ** 2 + ty ** 2) ** 0.5
if length > 0: # Perpendicular unit vector (rotate 90 degrees)
perp_x = -ty / length
perp_y = tx / length
else: # If tangent is zero, spread horizontally
perp_x = 1.0
perp_y = 0.0
frame_tracks = []
for track_idx in range(num_tracks): # center tracks around the main path offset ranges from -(num_tracks-1)/2 to +(num_tracks-1)/2
offset = (track_idx - (num_tracks - 1) / 2) * track_spread_px
track_x = x_positions[frame_idx].item() + perp_x * offset
track_y = y_positions[frame_idx].item() + perp_y * offset
frame_tracks.append([track_x, track_y])
track_list.append(frame_tracks)
tracks = torch.tensor(track_list, dtype=torch.float32, device=device) # [frames, num_tracks, 2]
if track_mask is None:
track_visibility = torch.ones((track_length, num_tracks), dtype=torch.bool, device=device)
else:
track_visibility = (track_mask > 0).any(dim=(1, 2)).unsqueeze(-1)
out_track_info = {}
out_track_info["track_path"] = tracks
out_track_info["track_visibility"] = track_visibility
return io.NodeOutput(out_track_info, track_length)
class WanMoveConcatTrack(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveConcatTrack",
category="conditioning/video_models",
inputs=[
io.Tracks.Input("tracks_1"),
io.Tracks.Input("tracks_2", optional=True),
],
outputs=[
io.Tracks.Output(),
],
)
@classmethod
def execute(cls, tracks_1=None, tracks_2=None) -> io.NodeOutput:
if tracks_2 is None:
return io.NodeOutput(tracks_1)
tracks_out = torch.cat([tracks_1["track_path"], tracks_2["track_path"]], dim=1) # Concatenate along the track dimension
mask_out = torch.cat([tracks_1["track_visibility"], tracks_2["track_visibility"]], dim=-1)
out_track_info = {}
out_track_info["track_path"] = tracks_out
out_track_info["track_visibility"] = mask_out
return io.NodeOutput(out_track_info)
class WanMoveTrackToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanMoveTrackToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Tracks.Input("tracks", optional=True),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01, tooltip="Strength of the track conditioning."),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("start_image"),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, strength, tracks=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
device=comfy.model_management.intermediate_device()
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=device)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
if tracks is not None and strength > 0.0:
tracks_path = tracks["track_path"][:length] # [T, N, 2]
num_tracks = tracks_path.shape[-2]
track_visibility = tracks.get("track_visibility", torch.ones((length, num_tracks), dtype=torch.bool, device=device))
track_pos = create_pos_embeddings(tracks_path, track_visibility, [4, 8, 8], height, width, track_num=num_tracks)
track_pos = comfy.utils.resize_to_batch_size(track_pos.unsqueeze(0), batch_size)
concat_latent_image_pos = replace_feature(concat_latent_image, track_pos, strength)
else:
concat_latent_image_pos = concat_latent_image
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image_pos, "concat_mask": mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanMoveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanMoveTrackToVideo,
WanMoveTracksFromCoords,
WanMoveConcatTrack,
WanMoveVisualizeTracks,
GenerateTracks,
]
async def comfy_entrypoint() -> WanMoveExtension:
return WanMoveExtension()

View File

@ -13,6 +13,7 @@ import asyncio
import torch
import comfy.model_management
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
BasicCache,
@ -669,6 +670,8 @@ class PromptExecutor:
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
set_preview_method(extra_data.get("preview_method"))
nodes.interrupt_processing(False)
if "client_id" in extra_data:

View File

@ -8,6 +8,8 @@ import folder_paths
import comfy.utils
import logging
default_preview_method = args.preview_method
MAX_PREVIEW_RESOLUTION = args.preview_size
VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"]
@ -125,3 +127,11 @@ def prepare_callback(model, steps, x0_output_dict=None):
pbar.update_absolute(step + 1, total_steps, preview_bytes)
return callback
def set_preview_method(override: str = None):
if override and override != "default":
method = LatentPreviewMethod.from_string(override)
if method is not None:
args.preview_method = method
return
args.preview_method = default_preview_method

View File

@ -1 +1 @@
comfyui_manager==4.0.3b4
comfyui_manager==4.0.3b5

View File

@ -2358,6 +2358,7 @@ async def init_builtin_extra_nodes():
"nodes_logic.py",
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
]
import_failed = []
@ -2383,7 +2384,6 @@ async def init_builtin_api_nodes():
"nodes_recraft.py",
"nodes_pixverse.py",
"nodes_stability.py",
"nodes_pika.py",
"nodes_runway.py",
"nodes_sora.py",
"nodes_topaz.py",

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.33.13
comfyui-workflow-templates==0.7.54
comfyui-frontend-package==1.34.9
comfyui-workflow-templates==0.7.59
comfyui-embedded-docs==0.3.1
torch
torchsde

View File

@ -0,0 +1,352 @@
"""
Unit tests for Queue-specific Preview Method Override feature.
Tests the preview method override functionality:
- LatentPreviewMethod.from_string() method
- set_preview_method() function in latent_preview.py
- default_preview_method variable
- Integration with args.preview_method
"""
import pytest
from comfy.cli_args import args, LatentPreviewMethod
from latent_preview import set_preview_method, default_preview_method
class TestLatentPreviewMethodFromString:
"""Test LatentPreviewMethod.from_string() classmethod."""
@pytest.mark.parametrize("value,expected", [
("auto", LatentPreviewMethod.Auto),
("latent2rgb", LatentPreviewMethod.Latent2RGB),
("taesd", LatentPreviewMethod.TAESD),
("none", LatentPreviewMethod.NoPreviews),
])
def test_valid_values_return_enum(self, value, expected):
"""Valid string values should return corresponding enum."""
assert LatentPreviewMethod.from_string(value) == expected
@pytest.mark.parametrize("invalid", [
"invalid",
"TAESD", # Case sensitive
"AUTO", # Case sensitive
"Latent2RGB", # Case sensitive
"latent",
"",
"default", # default is special, not a method
])
def test_invalid_values_return_none(self, invalid):
"""Invalid string values should return None."""
assert LatentPreviewMethod.from_string(invalid) is None
class TestLatentPreviewMethodEnumValues:
"""Test LatentPreviewMethod enum has expected values."""
def test_enum_values(self):
"""Verify enum values match expected strings."""
assert LatentPreviewMethod.NoPreviews.value == "none"
assert LatentPreviewMethod.Auto.value == "auto"
assert LatentPreviewMethod.Latent2RGB.value == "latent2rgb"
assert LatentPreviewMethod.TAESD.value == "taesd"
def test_enum_count(self):
"""Verify exactly 4 preview methods exist."""
assert len(LatentPreviewMethod) == 4
class TestSetPreviewMethod:
"""Test set_preview_method() function from latent_preview.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_override_with_taesd(self):
"""'taesd' should set args.preview_method to TAESD."""
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
def test_override_with_latent2rgb(self):
"""'latent2rgb' should set args.preview_method to Latent2RGB."""
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_override_with_auto(self):
"""'auto' should set args.preview_method to Auto."""
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
def test_override_with_none_value(self):
"""'none' should set args.preview_method to NoPreviews."""
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_default_restores_original(self):
"""'default' should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use 'default' to restore
set_preview_method("default")
assert args.preview_method == default_preview_method
def test_none_param_restores_original(self):
"""None parameter should restore to default_preview_method."""
# First override to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then use None to restore
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_empty_string_restores_original(self):
"""Empty string should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("")
assert args.preview_method == default_preview_method
def test_invalid_value_restores_original(self):
"""Invalid value should restore to default_preview_method."""
set_preview_method("taesd")
set_preview_method("invalid_method")
assert args.preview_method == default_preview_method
def test_case_sensitive_invalid_restores(self):
"""Case-mismatched values should restore to default."""
set_preview_method("taesd")
set_preview_method("TAESD") # Wrong case
assert args.preview_method == default_preview_method
class TestDefaultPreviewMethod:
"""Test default_preview_method module variable."""
def test_default_is_not_none(self):
"""default_preview_method should not be None."""
assert default_preview_method is not None
def test_default_is_enum_member(self):
"""default_preview_method should be a LatentPreviewMethod enum."""
assert isinstance(default_preview_method, LatentPreviewMethod)
def test_default_matches_args_initial(self):
"""default_preview_method should match CLI default or user setting."""
# This tests that default_preview_method was captured at module load
# After set_preview_method(None), args should equal default
original = args.preview_method
set_preview_method("taesd")
set_preview_method(None)
assert args.preview_method == default_preview_method
args.preview_method = original
class TestArgsPreviewMethodModification:
"""Test args.preview_method can be modified correctly."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_args_accepts_all_enum_values(self):
"""args.preview_method should accept all LatentPreviewMethod values."""
for method in LatentPreviewMethod:
args.preview_method = method
assert args.preview_method == method
def test_args_modification_and_restoration(self):
"""args.preview_method should be modifiable and restorable."""
original = args.preview_method
args.preview_method = LatentPreviewMethod.TAESD
assert args.preview_method == LatentPreviewMethod.TAESD
args.preview_method = original
assert args.preview_method == original
class TestExecutionFlow:
"""Test the execution flow pattern used in execution.py."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_sequential_executions_with_different_methods(self):
"""Simulate multiple queue executions with different preview methods."""
# Execution 1: taesd
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Execution 2: none
set_preview_method("none")
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Execution 3: default (restore)
set_preview_method("default")
assert args.preview_method == default_preview_method
# Execution 4: auto
set_preview_method("auto")
assert args.preview_method == LatentPreviewMethod.Auto
# Execution 5: no override (None)
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_override_then_default_pattern(self):
"""Test the pattern: override -> execute -> next call restores."""
# First execution with override
set_preview_method("latent2rgb")
assert args.preview_method == LatentPreviewMethod.Latent2RGB
# Second execution without override restores default
set_preview_method(None)
assert args.preview_method == default_preview_method
def test_extra_data_simulation(self):
"""Simulate extra_data.get('preview_method') patterns."""
# Simulate: extra_data = {"preview_method": "taesd"}
extra_data = {"preview_method": "taesd"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Simulate: extra_data = {}
extra_data = {}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
# Simulate: extra_data = {"preview_method": "default"}
extra_data = {"preview_method": "default"}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
class TestRealWorldScenarios:
"""Tests using real-world prompt data patterns."""
def setup_method(self):
"""Store original value before each test."""
self.original = args.preview_method
def teardown_method(self):
"""Restore original value after each test."""
args.preview_method = self.original
def test_captured_prompt_without_preview_method(self):
"""
Test with captured prompt that has no preview_method.
Based on: tests-unit/execution_test/fixtures/default_prompt.json
"""
# Real captured extra_data structure (preview_method absent)
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"create_time": 1765416558179
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_captured_prompt_with_preview_method_taesd(self):
"""Test captured prompt with preview_method: taesd."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "271314f0dabd48e5aaa488ed7a4ceb0d",
"preview_method": "taesd"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
def test_captured_prompt_with_preview_method_none(self):
"""Test captured prompt with preview_method: none (disable preview)."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "none"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
def test_captured_prompt_with_preview_method_latent2rgb(self):
"""Test captured prompt with preview_method: latent2rgb."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "latent2rgb"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB
def test_captured_prompt_with_preview_method_auto(self):
"""Test captured prompt with preview_method: auto."""
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "auto"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Auto
def test_captured_prompt_with_preview_method_default(self):
"""Test captured prompt with preview_method: default (use CLI setting)."""
# First set to something else
set_preview_method("taesd")
assert args.preview_method == LatentPreviewMethod.TAESD
# Then simulate a prompt with "default"
extra_data = {
"extra_pnginfo": {"workflow": {}},
"client_id": "test-client",
"preview_method": "default"
}
set_preview_method(extra_data.get("preview_method"))
assert args.preview_method == default_preview_method
def test_sequential_queue_with_different_preview_methods(self):
"""
Simulate real queue scenario: multiple prompts with different settings.
This tests the actual usage pattern in ComfyUI.
"""
# Queue 1: User wants TAESD preview
extra_data_1 = {"client_id": "client-1", "preview_method": "taesd"}
set_preview_method(extra_data_1.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.TAESD
# Queue 2: User wants no preview (faster execution)
extra_data_2 = {"client_id": "client-2", "preview_method": "none"}
set_preview_method(extra_data_2.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.NoPreviews
# Queue 3: User doesn't specify (use server default)
extra_data_3 = {"client_id": "client-3"}
set_preview_method(extra_data_3.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 4: User explicitly wants default
extra_data_4 = {"client_id": "client-4", "preview_method": "default"}
set_preview_method(extra_data_4.get("preview_method"))
assert args.preview_method == default_preview_method
# Queue 5: User wants latent2rgb
extra_data_5 = {"client_id": "client-5", "preview_method": "latent2rgb"}
set_preview_method(extra_data_5.get("preview_method"))
assert args.preview_method == LatentPreviewMethod.Latent2RGB

View File

@ -0,0 +1,358 @@
"""
E2E tests for Queue-specific Preview Method Override feature.
Tests actual execution with different preview_method values.
Requires a running ComfyUI server with models.
Usage:
COMFYUI_SERVER=http://localhost:8988 pytest test_preview_method_e2e.py -v -m preview_method
Note:
These tests execute actual image generation and wait for completion.
Tests verify preview image transmission based on preview_method setting.
"""
import os
import json
import pytest
import uuid
import time
import random
import websocket
import urllib.request
from pathlib import Path
# Server configuration
SERVER_URL = os.environ.get("COMFYUI_SERVER", "http://localhost:8988")
SERVER_HOST = SERVER_URL.replace("http://", "").replace("https://", "")
# Use existing inference graph fixture
GRAPH_FILE = Path(__file__).parent.parent / "inference" / "graphs" / "default_graph_sdxl1_0.json"
def is_server_running() -> bool:
"""Check if ComfyUI server is running."""
try:
request = urllib.request.Request(f"{SERVER_URL}/system_stats")
with urllib.request.urlopen(request, timeout=2.0):
return True
except Exception:
return False
def prepare_graph_for_test(graph: dict, steps: int = 5) -> dict:
"""Prepare graph for testing: randomize seeds and reduce steps."""
adapted = json.loads(json.dumps(graph)) # Deep copy
for node_id, node in adapted.items():
inputs = node.get("inputs", {})
# Handle both "seed" and "noise_seed" (used by KSamplerAdvanced)
if "seed" in inputs:
inputs["seed"] = random.randint(0, 2**32 - 1)
if "noise_seed" in inputs:
inputs["noise_seed"] = random.randint(0, 2**32 - 1)
# Reduce steps for faster testing (default 20 -> 5)
if "steps" in inputs:
inputs["steps"] = steps
return adapted
# Alias for backward compatibility
randomize_seed = prepare_graph_for_test
class PreviewMethodClient:
"""Client for testing preview_method with WebSocket execution tracking."""
def __init__(self, server_address: str):
self.server_address = server_address
self.client_id = str(uuid.uuid4())
self.ws = None
def connect(self):
"""Connect to WebSocket."""
self.ws = websocket.WebSocket()
self.ws.settimeout(120) # 2 minute timeout for sampling
self.ws.connect(f"ws://{self.server_address}/ws?clientId={self.client_id}")
def close(self):
"""Close WebSocket connection."""
if self.ws:
self.ws.close()
def queue_prompt(self, prompt: dict, extra_data: dict = None) -> dict:
"""Queue a prompt and return response with prompt_id."""
data = {
"prompt": prompt,
"client_id": self.client_id,
"extra_data": extra_data or {}
}
req = urllib.request.Request(
f"http://{self.server_address}/prompt",
data=json.dumps(data).encode("utf-8"),
headers={"Content-Type": "application/json"}
)
return json.loads(urllib.request.urlopen(req).read())
def wait_for_execution(self, prompt_id: str, timeout: float = 120.0) -> dict:
"""
Wait for execution to complete via WebSocket.
Returns:
dict with keys: completed, error, preview_count, execution_time
"""
result = {
"completed": False,
"error": None,
"preview_count": 0,
"execution_time": 0.0
}
start_time = time.time()
self.ws.settimeout(timeout)
try:
while True:
out = self.ws.recv()
elapsed = time.time() - start_time
if isinstance(out, str):
message = json.loads(out)
msg_type = message.get("type")
data = message.get("data", {})
if data.get("prompt_id") != prompt_id:
continue
if msg_type == "executing":
if data.get("node") is None:
# Execution complete
result["completed"] = True
result["execution_time"] = elapsed
break
elif msg_type == "execution_error":
result["error"] = data
result["execution_time"] = elapsed
break
elif msg_type == "progress":
# Progress update during sampling
pass
elif isinstance(out, bytes):
# Binary data = preview image
result["preview_count"] += 1
except websocket.WebSocketTimeoutException:
result["error"] = "Timeout waiting for execution"
result["execution_time"] = time.time() - start_time
return result
def load_graph() -> dict:
"""Load the SDXL graph fixture with randomized seed."""
with open(GRAPH_FILE) as f:
graph = json.load(f)
return randomize_seed(graph) # Avoid caching
# Skip all tests if server is not running
pytestmark = [
pytest.mark.skipif(
not is_server_running(),
reason=f"ComfyUI server not running at {SERVER_URL}"
),
pytest.mark.preview_method,
pytest.mark.execution,
]
@pytest.fixture
def client():
"""Create and connect a test client."""
c = PreviewMethodClient(SERVER_HOST)
c.connect()
yield c
c.close()
@pytest.fixture
def graph():
"""Load the test graph."""
return load_graph()
class TestPreviewMethodExecution:
"""Test actual execution with different preview methods."""
def test_execution_with_latent2rgb(self, client, graph):
"""
Execute with preview_method=latent2rgb.
Should complete and potentially receive preview images.
"""
extra_data = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
# Should complete (may error if model missing, but that's separate)
assert result["completed"] or result["error"] is not None
# Execution should take some time (sampling)
if result["completed"]:
assert result["execution_time"] > 0.5, "Execution too fast - likely didn't run"
# latent2rgb should produce previews
print(f"latent2rgb: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_taesd(self, client, graph):
"""
Execute with preview_method=taesd.
TAESD provides higher quality previews.
"""
extra_data = {"preview_method": "taesd"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
assert result["execution_time"] > 0.5
# taesd should also produce previews
print(f"taesd: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_none_preview(self, client, graph):
"""
Execute with preview_method=none.
No preview images should be generated.
"""
extra_data = {"preview_method": "none"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
# With "none", should receive no preview images
assert result["preview_count"] == 0, \
f"Expected no previews with 'none', got {result['preview_count']}"
print(f"none: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_with_default(self, client, graph):
"""
Execute with preview_method=default.
Should use server's CLI default setting.
"""
extra_data = {"preview_method": "default"}
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"default: {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
def test_execution_without_preview_method(self, client, graph):
"""
Execute without preview_method in extra_data.
Should use server's default preview method.
"""
extra_data = {} # No preview_method
response = client.queue_prompt(graph, extra_data)
assert "prompt_id" in response
result = client.wait_for_execution(response["prompt_id"])
assert result["completed"] or result["error"] is not None
if result["completed"]:
print(f"(no override): {result['preview_count']} previews in {result['execution_time']:.2f}s") # noqa: T201
class TestPreviewMethodComparison:
"""Compare preview behavior between different methods."""
def test_none_vs_latent2rgb_preview_count(self, client, graph):
"""
Compare preview counts: 'none' should have 0, others should have >0.
This is the key verification that preview_method actually works.
"""
results = {}
# Run with none (randomize seed to avoid caching)
graph_none = randomize_seed(graph)
extra_data_none = {"preview_method": "none"}
response = client.queue_prompt(graph_none, extra_data_none)
results["none"] = client.wait_for_execution(response["prompt_id"])
# Run with latent2rgb (randomize seed again)
graph_rgb = randomize_seed(graph)
extra_data_rgb = {"preview_method": "latent2rgb"}
response = client.queue_prompt(graph_rgb, extra_data_rgb)
results["latent2rgb"] = client.wait_for_execution(response["prompt_id"])
# Verify both completed
assert results["none"]["completed"], f"'none' execution failed: {results['none']['error']}"
assert results["latent2rgb"]["completed"], f"'latent2rgb' execution failed: {results['latent2rgb']['error']}"
# Key assertion: 'none' should have 0 previews
assert results["none"]["preview_count"] == 0, \
f"'none' should have 0 previews, got {results['none']['preview_count']}"
# 'latent2rgb' should have at least 1 preview (depends on steps)
assert results["latent2rgb"]["preview_count"] > 0, \
f"'latent2rgb' should have >0 previews, got {results['latent2rgb']['preview_count']}"
print("\nPreview count comparison:") # noqa: T201
print(f" none: {results['none']['preview_count']} previews") # noqa: T201
print(f" latent2rgb: {results['latent2rgb']['preview_count']} previews") # noqa: T201
class TestPreviewMethodSequential:
"""Test sequential execution with different preview methods."""
def test_sequential_different_methods(self, client, graph):
"""
Execute multiple prompts sequentially with different preview methods.
Each should complete independently with correct preview behavior.
"""
methods = ["latent2rgb", "none", "default"]
results = []
for method in methods:
# Randomize seed for each execution to avoid caching
graph_run = randomize_seed(graph)
extra_data = {"preview_method": method}
response = client.queue_prompt(graph_run, extra_data)
result = client.wait_for_execution(response["prompt_id"])
results.append({
"method": method,
"completed": result["completed"],
"preview_count": result["preview_count"],
"execution_time": result["execution_time"],
"error": result["error"]
})
# All should complete or have clear errors
for r in results:
assert r["completed"] or r["error"] is not None, \
f"Method {r['method']} neither completed nor errored"
# "none" should have zero previews if completed
none_result = next(r for r in results if r["method"] == "none")
if none_result["completed"]:
assert none_result["preview_count"] == 0, \
f"'none' should have 0 previews, got {none_result['preview_count']}"
print("\nSequential execution results:") # noqa: T201
for r in results:
status = "" if r["completed"] else f"✗ ({r['error']})"
print(f" {r['method']}: {status}, {r['preview_count']} previews, {r['execution_time']:.2f}s") # noqa: T201