Compare commits

..

44 Commits

Author SHA1 Message Date
eason
a3d2d35979
Merge 4b97d167f1 into 9c34f5f36a 2026-05-06 14:03:38 +02:00
Comfy Org PR Bot
9c34f5f36a
Bump comfyui-frontend-package to 1.43.17 (#13723)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Alexander Brown <DrJKL0424@gmail.com>
2026-05-05 22:22:48 -07:00
Talmaj
78b3096bf3
Void model - pass 1 & 2 (CORE-38) (#13403) 2026-05-05 19:59:04 -07:00
Luke Mino-Altherr
2b63add0ad
fix: return millisecond timestamps from get_file_info() (#12996) 2026-05-06 10:56:09 +08:00
iChrist
160b95f75c
Update language options in nodes_ace.py (#12578)
* Update language options in nodes_ace.py

Modified it to include all 51 language options ace-step1.5 supports instead of the original 23 comfyui had.

* re-arrange list by popularity

changed order of the languages to be ordered by popularity

en is default 
unknown is last

* Update comfy_extras/nodes_ace.py
2026-05-05 19:47:57 -07:00
comfyanonymous
c168960a12
First step of supporting save filenames without trailing _ (#13722)
get_save_image_path now properly supports filenames without
trailing underscores.

This will be the saving behavior when using a mix of save image nodes using the old and the new format.

ComfyUI_00001_.png
ComfyUI_00002.png
ComfyUI_00003.png
ComfyUI_00004_.png
2026-05-05 17:00:11 -07:00
drozbay
e5369c0eec
feat: Context windows - add causal_window_fix to improve blending of context windows (CORE-100) (#13563)
* Context windows: add causal_window_fix toggle

* Fix slice_cond to correctly handle causal anchor index for temporal offsets
2026-05-05 16:40:53 -07:00
drozbay
1655f8089a
Add temporal_downscale_ratio to LatentFormat (#13702)
Co-authored-by: ozbayb <17261091+ozbayb@users.noreply.github.com>
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
Co-authored-by: Jukka Seppänen <40791699+kijai@users.noreply.github.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2026-05-05 16:30:00 -07:00
Matt Miller
89014792c9
feat: add cloud-specific fields to OSS openapi.yaml as nullable (#13623)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* feat: add cloud-specific fields to OSS openapi.yaml as nullable

Add cross-runtime fields with x-runtime: [cloud] extension and [cloud-only]
description prefix per the convention established in BE-613. All new fields
are nullable and not in required arrays, so they are purely additive.

/api/features response:
- max_upload_size (integer, int64)
- free_tier_credits (integer, int32)
- posthog_api_host (string, uri)
- max_concurrent_jobs (integer, int32)
- workflow_templates_version (string)
- workflow_templates_source (string, enum)

PromptRequest schema:
- workflow_id (string, uuid)
- workflow_version_id (string, uuid)

POST /api/assets:
- id field (uuid) on multipart/form-data for idempotent creation
- application/json alternate content-type for URL-based uploads

POST /api/assets/from-hash:
- mime_type (string) to preserve type without re-inspection

PUT /api/assets/{id}:
- mime_type (string) for overriding auto-detection

GET /api/assets additional query parameters:
- job_ids (string) — filter by associated job UUIDs
- include_public (boolean) — include workspace-public assets
- asset_hash (string) — filter by exact content hash

Resolves: BE-613
Blocks: BE-364, BE-361, BE-363

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>

* fix(openapi): address CodeRabbit feedback (BE-613)

- max_upload_size is set in both runtimes via SERVER_FEATURE_FLAGS;
  drop the cloud-only / nullable tagging.
- Require `url` on the application/json POST /api/assets body so the
  contract is enforceable by validators and codegen.

---------

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>
2026-05-05 14:20:09 -07:00
Jedrzej Kosinski
431fadb520
fix(api-io): serialize MultiCombo multi_select as object config (#13484)
* fix(api-io): serialize MultiCombo multi_select as object config
* fix: remove dead code and redundant top-level keys from MultiCombo serialization
* fix: correct skip warning to mention comfy_entrypoint, remove nonexistent NODES_LIST
* fix: validate MultiCombo list values against options individually
* fix: gate multiselect validation on schema config, improve error message, add tests

---------

Co-authored-by: Ni-zav <ni-zav@users.noreply.github.com>
Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-05-05 13:58:32 -07:00
Matt Miller
1ac60da2c9
Add Spectral lint CI gate for openapi.yaml (#13410)
* Add Spectral lint CI gate for openapi.yaml

Adds a blocking Spectral lint check that runs on PRs touching
openapi.yaml or the ruleset itself. The ruleset mirrors the one used
for other Comfy-Org service specs: spectral:oas plus conventions for
snake_case properties, camelCase operationIds, and response/schema
shape. Gate runs at --fail-severity=error, which the spec currently
passes with zero errors (a small number of non-blocking
warnings/hints remain for WebSocket 101 responses, the existing loose
error schema, and two snake_case wire fields).

* ci: set least-privilege contents:read permissions on openapi-lint workflow

Per CodeRabbit review on #13410. The job only checks out the repo and
runs Spectral, so contents:read is sufficient and avoids inheriting any
permissive repo/org default token scope.

---------

Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-05-05 13:21:36 -07:00
drozbay
41d73ad180
fix(audio): drop sample_rate key from LTXVEmptyLatentAudio (CORE-157) (#13716) 2026-05-05 11:33:16 -07:00
THE MACHINE
ea6880b04b
Fix Content-Disposition header missing 'attachment;' prefix (#13093)
Add missing 'attachment;' directive to Content-Disposition headers in
server.py to ensure browsers properly download files instead of
attempting to display them inline.

Fixes 4 instances in the file download endpoint.

Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-05-05 11:00:03 -07:00
Alexis Rolland
639f631a08
chore: Update display names and categories for text nodes (CORE-155) (#13712)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-05 22:31:24 +08:00
Daxiong (Lin)
d794b62939
Update workflow templates to v0.9.69 (#13714)
* chore: update workflow templates to v0.9.69

* Update comfyui-workflow-templates to version 0.9.70

* Downgrade comfyui-workflow-templates to 0.9.69

---------

Co-authored-by: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
2026-05-05 16:57:27 +03:00
Alexander Piskun
6917bce128
[Partner Nodes] add Gpt 5.5 and 5.5-pro LLM models (#13673)
* feat(api-nodes): add Gpt 5.5 and 5.5-pro LLM models

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-05 16:53:19 +03:00
Alexander Piskun
c55ff85243
feat(api-nodes): add Luma UNI-1 models (#13614)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-05-05 16:49:07 +03:00
Alvin Tang
8d75211300
fix: SplitImageToTileList and ImageMergeTileList to use tile_height for vertical stride minimum (#12882) 2026-05-05 20:29:11 +08:00
Talmaj
fed8d5efa6
feat: Auto-regressive video generation (CORE-25) (#13082)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-04 21:01:22 -07:00
comfyanonymous
9aef025fb0
Document core release frequency is now ~2 weeks. (#13710) 2026-05-04 20:45:48 -07:00
Jedrzej Kosinski
e758594e3b
Add deploy environment header (Comfy-Env) to partner node API calls (#13425) 2026-05-04 20:17:56 -07:00
Jedrzej Kosinski
ae457da84b
feat: add generic --feature-flag CLI arg and --list-feature-flags registry (#13685) 2026-05-04 19:50:26 -07:00
Matt Miller
413e250ccd
spec: add workflow_id / workflow_version_id to PromptRequest with x-runtime tag (#13709)
Adds two optional, nullable UUID fields to PromptRequest for runtimes
that wrap workflow execution in a workflow-version entity (the
hosted-cloud runtime does this; local ComfyUI does not). Both fields
are tagged `x-runtime: [cloud]` to mark them as runtime-specific —
local ComfyUI returns `null` (or omits them entirely) and that's
correct behavior, not drift.

## Why these fields belong in the OSS spec

Hosted-cloud's frontend and backend share `openapi.yaml` as their
single source of truth via auto-generated client types. Without the
fields declared in the spec, the cloud runtime has to either:

  1. Hand-edit a vendored copy of openapi.yaml (drift between vendor
     and upstream — unsustainable).
  2. Maintain a separate cloud-only spec file (forks the contract,
     defeats the point of a shared OSS spec).

Both options have been tried and both produce maintenance pain. The
shape that scales is: cloud-only fields live in OSS spec under their
intended path, declared nullable, with an explicit `x-runtime` tag so
local-only readers can ignore them programmatically and human readers
can see what each runtime populates.

## About the `x-runtime` extension

This is the first use of `x-runtime` in this spec. Convention:

  - `x-runtime: [cloud]` — only the hosted-cloud runtime populates the
    field; local returns null or omits.
  - `x-runtime: [local]` — only local populates; cloud returns null.
  - Tag absent — both runtimes populate the field (the default).

This is a vendor extension (`x-` prefix) and is ignored by spec
validators that don't recognize it, including `kin-openapi`. Local
clients reading the spec see two extra optional nullable fields, which
is forward-compatible with all existing readers.

## What this does not change

  - No Python code changes. `PromptRequest` already accepts arbitrary
    optional fields (`extra_data: additionalProperties: true` on the
    same schema is a stronger guarantee). The Python server already
    silently accepts and ignores both fields today.
  - No required-fields change. Both fields stay outside `required`,
    so older clients that don't know about them keep validating.
  - No nullability widening on existing fields.

## Verification

  - YAML parses (`yaml.safe_load`).
  - `kin-openapi` `loader.LoadFromFile` accepts the modified spec.
  - `openapi3filter.ValidateRequest` on a PromptRequest with both
    fields set to `null`, set to a valid UUID, or omitted — all pass.
2026-05-04 18:59:48 -07:00
Matt Miller
35819e35a8
fix(spec): mark DeviceStats.index and NodeInfo.essentials_category as nullable (#13706)
* fix(spec): mark DeviceStats.index and NodeInfo.essentials_category as nullable

Two fields in openapi.yaml are declared as required/non-nullable but
the Python implementation legitimately returns `null` for them, so any
client that response-validates against the spec will fail.

`DeviceStats.index` (used by GET /api/system_stats):
- server.py emits `"index": device.index` unconditionally
- For the CPU device (--cpu mode), `torch.device("cpu").index` is `None`
- → JSON response includes `"index": null` for CPU devices

`NodeInfo.essentials_category` (used by GET /api/object_info):
- The V3 schema-based path (comfy_api/latest/_io.py:1654) unconditionally
  passes `essentials_category=self.essentials_category` into NodeInfoV1
  and serializes via dataclasses.asdict(), so the key is always present
- Schema's `essentials_category` defaults to `None` for nodes that
  don't set it in `define_schema` (e.g. the APG node)
- → JSON response includes `"essentials_category": null` for those nodes
- (The V1 path in server.py uses `hasattr` and so omits the key
  entirely when not set, but the V3 path is the one that produces nulls)

Both fields keep their existing `required` status — they're always
present in the response, the value is just nullable. Descriptions
expanded to spell out when `null` is expected.

* docs(spec): clarify essentials_category presence rules

The previous description said "null for nodes that don't set
ESSENTIALS_CATEGORY (V1)" — that's wrong. server.py:739-740 uses
`hasattr` and OMITS the key when the V1 attribute isn't defined; null
only happens if the attribute is explicitly set to None. Spell out
all three legal shapes (string / null / absent) and which path
produces which.
2026-05-04 18:28:21 -07:00
Alexis Rolland
15a4494a4e
chore: Update display names and categories (CORE-151) (#13693)
* Standardize DEPRECATED label in node display name

* Promote category image/video to root level video/

* Update images and masks names and categories
2026-05-04 17:37:25 -07:00
rattus
1265955b34
ops: handle multi-compute of the same weight (#13705)
If the same weight is used multiple times within the same prefetch
window, it should only apply compute state mutations once. Mark the
weight as fully resident on the first pass accordingly.
2026-05-04 16:40:57 -07:00
rattus
1ac78180b3
make control-net load order deterministic (#13701)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Make this deterministic so speeds dont change base of load order. Load
them in reverse order so whatever the caller lists first is the top
priority.
2026-05-04 12:58:06 -07:00
rattus
c47633f3be
prefetch: guard against no offload (#13703)
cast_ will return no stream if there is no work to do. guard against
this is the consume logic.
2026-05-04 12:56:05 -07:00
Jukka Seppänen
c33d26c283
fix: Proper memory estimation for frame interpolation when not using dynamic VRAM (#13698) 2026-05-04 20:20:40 +03:00
Soof Golan
f3ea976cba
Fix a1111 typo in extra_model_paths.yaml (#2720)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-04 16:01:46 +08:00
Alexis Rolland
5538f62b0b
fix: Update ColorTransfer node ref_image to be mandatory (#13691)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-04 12:33:11 +08:00
Jedrzej Kosinski
2806163f6e
Default control_after_generate to fixed in PrimitiveInt node (#13690) 2026-05-04 07:21:34 +08:00
comfyanonymous
cea8d0925f
Refactor LoadImageMask to use LoadImage code. (#13687)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-03 16:18:27 -04:00
Silver
b138133ffa
Enable triton comfy kitchen via cli-arg (#12730) 2026-05-03 14:07:21 -04:00
Jukka Seppänen
025e6792ee
Batch broadcasting in JoinImageWithAlpha node (#13686)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
* Batch broadcasting in JoinImageWithAlpha node
2026-05-03 16:30:00 +03:00
Luke Mino-Altherr
867b8d2408
fix: gracefully handle port-in-use error on server startup (#13001)
Catch EADDRINUSE OSError when binding the TCP site and exit with a clear error message instead of an unhandled traceback.
2026-05-03 20:44:20 +08:00
Alexis Rolland
d0f0b15cf5
Update ComfyUI screenshot in README (#13683)
Update ComfyUI screenshot to showcase a more modern workflow
2026-05-03 18:48:58 +08:00
Alexis Rolland
b5bb83c964
Fix issue blend images with alpha (#13615)
Make ImageBlend and ImageCompositeMasked nodes handle images with different channel counts
2026-05-03 18:17:08 +08:00
Alexis Rolland
f6d5068ac0
Update README (#13679)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Updated the README to include a new screenshot, improved description and add Ernie Image to supported models.
2026-05-03 12:20:17 +08:00
Jukka Seppänen
be95871adc
feat: Gemma4 text generation support (CORE-30) (#13376)
* initial gemma4 support

* parity with reference implementation

outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize

* Cleanup, video fixes

* cleanup, enable fused rms norm by default

* update comment

* Cleanup

* Update sd.py

* Various fixes

* Add fp8 scaled embedding support

* small fixes

* Translate think tokens

* Fix image encoder attention mask type

So it works with basic attention

* Handle thinking tokens different only for Gemma4

* Code cleanup

* Update nodes_textgen.py

* Use embed scale class instead of buffer

Slight difference to HF, but technically more accurate and simpler code

* Default to fused rms_norm

* Update gemma4.py
2026-05-02 22:46:15 -04:00
Alexander Piskun
f756d801a1
[Partner Nodes] Topaz Astra 2 model (#13672)
* feat(api-nodes): add Topaz Astra 2 model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat(api-nodes): make Astra 2 the default Topaz upscaler model

Reorder UPSCALER_MODELS_MAP and the upscaler_model dynamic combo so
"Astra 2" appears first, surfacing it as the default selection.

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Marwan Mostafa <marawan206@gmail.com>
2026-05-02 19:29:00 -07:00
Daxiong (Lin)
1d23a875ed
chore: update workflow templates to v0.9.68 (#13678) 2026-05-03 10:06:55 +08:00
comfyanonymous
ef6722f6be
Some cleanups to the load image node. (#13677) 2026-05-02 20:34:27 -04:00
rattus
783782d5d7
Implement block prefetch + Lora Async load + and adopt in LTX (Speedup!) (CORE-111) (#13618)
* mm: Use Aimdo raw allocator for cast buffers

pytorch manages allocation of growing buffers on streams poorly. Pyt
has no windows support for the expandable segments allocator (which is
the right tool for this job), while also segmenting the memory by
stream such that it can be generally re-used. So kick the problem to
aimdo which can just grow a virtual region thats freed per stream.

* plan

* ops: move cpu handler up to the caller

* ops: split up prefetch from weight prep block prefetching API

Split up the casting and weight formating/lora stuff in prep for
arbitrary prefetch support.

* ops: implement block prefetching API

allow a model to construct a prefetch list and operate it for increased
async offload.

* ltxv2: Implement block prefetching

* Implement lora async offload

Implement async offload of loras.
2026-05-02 19:23:24 -04:00
74 changed files with 4914 additions and 316 deletions

31
.github/workflows/openapi-lint.yml vendored Normal file
View File

@ -0,0 +1,31 @@
name: OpenAPI Lint
on:
pull_request:
paths:
- 'openapi.yaml'
- '.spectral.yaml'
- '.github/workflows/openapi-lint.yml'
permissions:
contents: read
jobs:
spectral:
name: Run Spectral
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Install Spectral
run: npm install -g @stoplight/spectral-cli@6
- name: Lint openapi.yaml
run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error

1
.gitignore vendored
View File

@ -23,3 +23,4 @@ web_custom_versions/
.DS_Store .DS_Store
filtered-openapi.yaml filtered-openapi.yaml
uv.lock uv.lock
.comfy_environment

91
.spectral.yaml Normal file
View File

@ -0,0 +1,91 @@
extends:
- spectral:oas
# Severity levels: error, warn, info, hint, off
# Rules from the built-in "spectral:oas" ruleset are active by default.
# Below we tune severity and add custom rules for our conventions.
#
# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the
# organization are linted against a single consistent standard.
rules:
# -----------------------------------------------------------------------
# Built-in rule severity overrides
# -----------------------------------------------------------------------
operation-operationId: error
operation-description: warn
operation-tag-defined: error
info-contact: off
info-description: warn
no-eval-in-markdown: error
no-$ref-siblings: error
# -----------------------------------------------------------------------
# Custom rules: naming conventions
# -----------------------------------------------------------------------
# Property names should be snake_case
property-name-snake-case:
description: Property names must be snake_case
severity: warn
given: "$.components.schemas.*.properties[*]~"
then:
function: pattern
functionOptions:
match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$"
# Operation IDs should be camelCase
operation-id-camel-case:
description: Operation IDs must be camelCase
severity: warn
given: "$.paths.*.*.operationId"
then:
function: pattern
functionOptions:
match: "^[a-z][a-zA-Z0-9]*$"
# -----------------------------------------------------------------------
# Custom rules: response conventions
# -----------------------------------------------------------------------
# Error responses (4xx, 5xx) should use a consistent shape
error-response-schema:
description: Error responses should reference a standard error schema
severity: hint
given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema"
then:
field: "$ref"
function: truthy
# All 2xx responses with JSON body should have a schema
response-schema-defined:
description: Success responses with JSON content should define a schema
severity: warn
given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']"
then:
field: schema
function: truthy
# -----------------------------------------------------------------------
# Custom rules: best practices
# -----------------------------------------------------------------------
# Path parameters must have a description
path-param-description:
description: Path parameters should have a description
severity: warn
given:
- "$.paths.*.parameters[?(@.in == 'path')]"
- "$.paths.*.*.parameters[?(@.in == 'path')]"
then:
field: description
function: truthy
# Schemas should have a description
schema-description:
description: Component schemas should have a description
severity: hint
given: "$.components.schemas.*"
then:
field: description
function: truthy

View File

@ -1,7 +1,7 @@
<div align="center"> <div align="center">
# ComfyUI # ComfyUI
**The most powerful and modular visual AI engine and application.** **The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url] [![Website][website-shield]][website-url]
@ -31,10 +31,16 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe) <img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
<br>
</div> </div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS. ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started ## Get Started
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Ernie Image
- Image Editing Models - Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@ -126,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week. - Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks.
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
- Minor versions will be used for releases off the master branch. - Minor versions will be used for releases off the master branch.
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.

View File

@ -28,8 +28,8 @@ def get_file_info(path: str, relative_to: str) -> FileInfo:
return { return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'), "path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path), "size": os.path.getsize(path),
"modified": os.path.getmtime(path), "modified": int(os.path.getmtime(path) * 1000),
"created": os.path.getctime(path) "created": int(os.path.getctime(path) * 1000),
} }

View File

@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
class LatentPreviewMethod(enum.Enum): class LatentPreviewMethod(enum.Enum):
NoPreviews = "none" NoPreviews = "none"
@ -237,6 +238,8 @@ database_default_path = os.path.abspath(
) )
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()

View File

@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
dim = self.dim dim = self.dim
if dim == 0 and full.shape[dim] == 1: if dim == 0 and full.shape[dim] == 1:
return full return full
idx = tuple([slice(None)] * dim + [self.index_list]) indices = self.index_list
anchor_idx = getattr(self, 'causal_anchor_index', None)
if anchor_idx is not None and anchor_idx >= 0:
indices = [anchor_idx] + list(indices)
idx = tuple([slice(None)] * dim + [indices])
window = full[idx] window = full[idx]
if retain_index_list: if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list]) idx = tuple([slice(None)] * dim + [retain_index_list])
@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames) # skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0: if temporal_offset > 0:
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]] anchor_idx = getattr(window, 'causal_anchor_index', None)
if anchor_idx is not None and anchor_idx >= 0:
# anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
skip_count = temporal_offset - 1
else:
skip_count = temporal_offset
indices = [i - temporal_offset for i in window.index_list[skip_count:]]
indices = [i for i in indices if 0 <= i] indices = [i for i in indices if 0 <= i]
else: else:
indices = list(window.index_list) indices = list(window.index_list)
@ -150,7 +161,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC): class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
causal_window_fix: bool=True):
self.context_schedule = context_schedule self.context_schedule = context_schedule
self.fuse_method = fuse_method self.fuse_method = fuse_method
self.context_length = context_length self.context_length = context_length
@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows self.split_conds_to_windows = split_conds_to_windows
self.causal_window_fix = causal_window_fix
self.callbacks = {} self.callbacks = {}
@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
# allow processing to end between context window executions for faster Cancel # allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
anchor_applied = False
if self.causal_window_fix:
anchor_idx = window.index_list[0] - 1
if 0 <= anchor_idx < x_in.size(self.dim):
window.causal_anchor_index = anchor_idx
anchor_applied = True
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks): for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device) callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
if device is not None: if device is not None:
for i in range(len(sub_conds_out)): for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device) sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
# strip causal_window_fix anchor if applied
if anchor_applied:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window)) results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
return results return results

View File

@ -0,0 +1,34 @@
import functools
import logging
import os
logger = logging.getLogger(__name__)
_DEFAULT_DEPLOY_ENV = "local-git"
_ENV_FILENAME = ".comfy_environment"
# Resolve the ComfyUI install directory (the parent of this `comfy/` package).
# We deliberately avoid `folder_paths.base_path` here because that is overridden
# by the `--base-directory` CLI arg to a user-supplied path, whereas the
# `.comfy_environment` marker is written by launchers/installers next to the
# ComfyUI install itself.
_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@functools.cache
def get_deploy_environment() -> str:
env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME)
try:
with open(env_file, encoding="utf-8") as f:
# Cap the read so a malformed or maliciously crafted file (e.g.
# a single huge line with no newline) can't blow up memory.
first_line = f.readline(128).strip()
value = "".join(c for c in first_line if 32 <= ord(c) < 127)
if value:
return value
except FileNotFoundError:
pass
except Exception as e:
logger.error("Failed to read %s: %s", env_file, e)
return _DEFAULT_DEPLOY_ENV

View File

@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023).""" """Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
@torch.no_grad()
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
num_frame_per_block=1):
"""
Autoregressive video sampler: block-by-block denoising with KV cache
and flow-match re-noising for Causal Forcing / Self-Forcing models.
Requires a Causal-WAN compatible model (diffusion_model must expose
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
All AR-loop parameters are passed via the SamplerARVideo node, not read
from the checkpoint or transformer_options.
"""
extra_args = {} if extra_args is None else extra_args
model_options = extra_args.get("model_options", {})
transformer_options = model_options.get("transformer_options", {})
if x.ndim != 5:
raise ValueError(
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
)
inner_model = model.inner_model.inner_model
causal_model = inner_model.diffusion_model
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
raise TypeError(
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
"does not support this interface — choose a different sampler."
)
seed = extra_args.get("seed", 0)
bs, c, lat_t, lat_h, lat_w = x.shape
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
device = x.device
model_dtype = inner_model.get_dtype()
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
output = torch.zeros_like(x)
s_in = x.new_ones([x.shape[0]])
current_start_frame = 0
num_sigma_steps = len(sigmas) - 1
total_real_steps = num_blocks * num_sigma_steps
step_count = 0
try:
for block_idx in trange(num_blocks, disable=disable):
bf = min(num_frame_per_block, lat_t - current_start_frame)
fs, fe = current_start_frame, current_start_frame + bf
noisy_input = x[:, :, fs:fe]
ar_state = {
"start_frame": current_start_frame,
"kv_caches": kv_caches,
"crossattn_caches": crossattn_caches,
}
transformer_options["ar_state"] = ar_state
for i in range(num_sigma_steps):
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
if callback is not None:
scaled_i = step_count * num_sigma_steps // total_real_steps
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
"sigma_hat": sigmas[i], "denoised": denoised})
if sigmas[i + 1] == 0:
noisy_input = denoised
else:
sigma_next = sigmas[i + 1]
torch.manual_seed(seed + block_idx * 1000 + i)
fresh_noise = torch.randn_like(denoised)
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
step_count += 1
output[:, :, fs:fe] = noisy_input
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
zero_sigma = sigmas.new_zeros([1])
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
current_start_frame += bf
finally:
transformer_options.pop("ar_state", None)
return output

View File

@ -9,6 +9,7 @@ class LatentFormat:
latent_rgb_factors_reshape = None latent_rgb_factors_reshape = None
taesd_decoder_name = None taesd_decoder_name = None
spacial_downscale_ratio = 8 spacial_downscale_ratio = 8
temporal_downscale_ratio = 1
def process_in(self, latent): def process_in(self, latent):
return latent * self.scale_factor return latent * self.scale_factor
@ -235,6 +236,7 @@ class Flux2(LatentFormat):
class Mochi(LatentFormat): class Mochi(LatentFormat):
latent_channels = 12 latent_channels = 12
latent_dimensions = 3 latent_dimensions = 3
temporal_downscale_ratio = 6
def __init__(self): def __init__(self):
self.scale_factor = 1.0 self.scale_factor = 1.0
@ -278,6 +280,7 @@ class LTXV(LatentFormat):
latent_channels = 128 latent_channels = 128
latent_dimensions = 3 latent_dimensions = 3
spacial_downscale_ratio = 32 spacial_downscale_ratio = 32
temporal_downscale_ratio = 8
def __init__(self): def __init__(self):
self.latent_rgb_factors = [ self.latent_rgb_factors = [
@ -421,6 +424,7 @@ class LTXAV(LTXV):
class HunyuanVideo(LatentFormat): class HunyuanVideo(LatentFormat):
latent_channels = 16 latent_channels = 16
latent_dimensions = 3 latent_dimensions = 3
temporal_downscale_ratio = 4
scale_factor = 0.476986 scale_factor = 0.476986
latent_rgb_factors = [ latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445], [-0.0395, -0.0331, 0.0445],
@ -447,6 +451,7 @@ class HunyuanVideo(LatentFormat):
class Cosmos1CV8x8x8(LatentFormat): class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16 latent_channels = 16
latent_dimensions = 3 latent_dimensions = 3
temporal_downscale_ratio = 8
latent_rgb_factors = [ latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423], [ 0.1817, 0.2284, 0.2423],
@ -472,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat):
class Wan21(LatentFormat): class Wan21(LatentFormat):
latent_channels = 16 latent_channels = 16
latent_dimensions = 3 latent_dimensions = 3
temporal_downscale_ratio = 4
latent_rgb_factors = [ latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932], [-0.1299, -0.1692, 0.2932],
@ -734,6 +740,7 @@ class HunyuanVideo15(LatentFormat):
latent_channels = 32 latent_channels = 32
latent_dimensions = 3 latent_dimensions = 3
spacial_downscale_ratio = 16 spacial_downscale_ratio = 16
temporal_downscale_ratio = 4
scale_factor = 1.03682 scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5" taesd_decoder_name = "lighttaehy1_5"
@ -786,8 +793,27 @@ class ZImagePixelSpace(ChromaRadiance):
pass pass
class CogVideoX(LatentFormat): class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
scale_factor matches the vae/config.json scaling_factor for the 2b variant.
The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*)
use a different value; see CogVideoX1_5 below.
"""
latent_channels = 16 latent_channels = 16
latent_dimensions = 3 latent_dimensions = 3
temporal_downscale_ratio = 4
def __init__(self): def __init__(self):
self.scale_factor = 1.15258426 self.scale_factor = 1.15258426
class CogVideoX1_5(CogVideoX):
"""Latent format for 5b-class CogVideoX checkpoints.
Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun
V1.5-5b family (including VOID inpainting). All of these have
scaling_factor=0.7 in their vae/config.json. Auto-selected in
supported_models.CogVideoX_T2V based on transformer hidden dim.
"""
def __init__(self):
self.scale_factor = 0.7

View File

@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.model_prefetch
class CompressedTimestep: class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing.""" """Store video timestep embeddings in compressed form using per-frame indexing."""
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV.""" """Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks # Process transformer blocks
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep, a_prompt_timestep=a_prompt_timestep,
) )
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax] return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled(): if model_management.xformers_enabled():
import xformers import xformers
import xformers.ops import xformers.ops
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads h = heads
if skip_reshape: if skip_reshape:
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape b, _, dim_head = query.shape
dim_head //= heads dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape: if skip_reshape:
query = query.reshape(b * heads, -1, dim_head) query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape: if skip_reshape:
q, k, v = map( q, k, v = map(
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3: if mask.ndim == 3:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b: if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape: if not skip_output_reshape:
out = ( out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head) out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT],
attn_mask=m, attn_mask=m,
dropout_p=0.0, is_causal=False dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out return out

276
comfy/ldm/wan/ar_model.py Normal file
View File

@ -0,0 +1,276 @@
"""
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
autoregressive (frame-by-frame) video generation via Causal Forcing.
Weight-compatible with the standard WanModel -- same layer names, same shapes.
The difference is purely in the forward pass: this model processes one temporal
block at a time and maintains a KV cache across blocks.
Reference: https://github.com/thu-ml/Causal-Forcing
"""
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.wan.model import (
sinusoidal_embedding_1d,
repeat_e,
WanModel,
WanAttentionBlock,
)
import comfy.ldm.common_dit
import comfy.model_management
class CausalWanSelfAttention(nn.Module):
"""Self-attention with KV cache support for autoregressive inference."""
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
eps=1e-6, operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qk_norm = qk_norm
self.eps = eps
ops = operation_settings.get("operations")
device = operation_settings.get("device")
dtype = operation_settings.get("dtype")
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
v = self.v(x).view(b, s, n, d)
if kv_cache is None:
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v.view(b, s, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
else:
end = kv_cache["end"]
new_end = end + s
# Roped K and plain V go into cache
kv_cache["k"][:, end:new_end] = k
kv_cache["v"][:, end:new_end] = v
kv_cache["end"] = new_end
x = optimized_attention(
q.view(b, s, n * d),
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
return x
class CausalWanAttentionBlock(WanAttentionBlock):
"""Transformer block with KV-cached self-attention and cross-attention caching."""
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps,
operation_settings=operation_settings)
self.self_attn = CausalWanSelfAttention(
dim, num_heads, window_size, qk_norm, eps,
operation_settings=operation_settings)
def forward(self, x, e, freqs, context, context_img_len=257,
kv_cache=None, crossattn_cache=None, transformer_options={}):
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# Self-attention with optional KV cache
x = x.contiguous()
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# Cross-attention with optional caching
if crossattn_cache is not None and crossattn_cache.get("is_init"):
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
x_ca = optimized_attention(
q, crossattn_cache["k"], crossattn_cache["v"],
heads=self.num_heads, transformer_options=transformer_options)
x = x + self.cross_attn.o(x_ca)
else:
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if crossattn_cache is not None:
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
crossattn_cache["v"] = self.cross_attn.v(context)
crossattn_cache["is_init"] = True
# FFN
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
class CausalWanModel(WanModel):
"""
Wan 2.1 diffusion backbone with causal KV-cache support.
Same weight structure as WanModel -- loads identical state dicts.
Adds forward_block() for frame-by-frame autoregressive inference.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
image_model=None,
device=None,
dtype=None,
operations=None):
super().__init__(
model_type=model_type, patch_size=patch_size, text_len=text_len,
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
wan_attn_block_class=CausalWanAttentionBlock,
device=device, dtype=dtype, operations=operations)
def forward_block(self, x, timestep, context, start_frame,
kv_caches, crossattn_caches, clip_fea=None):
"""
Forward one temporal block for autoregressive inference.
Args:
x: [B, C, block_frames, H, W] input latent for the current block
timestep: [B, block_frames] per-frame timesteps
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
start_frame: temporal frame index for RoPE offset
kv_caches: list of per-layer KV cache dicts
crossattn_caches: list of per-layer cross-attention cache dicts
clip_fea: optional CLIP features for I2V
Returns:
flow_pred: [B, C_out, block_frames, H, W] flow prediction
"""
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
bs, c, t, h, w = x.shape
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# Per-frame time embedding
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# Text embedding (reuses crossattn_cache after first block)
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea)
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
# RoPE for current block's temporal position
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
# Transformer blocks
for i, block in enumerate(self.blocks):
x = block(x, e=e0, freqs=freqs, context=context,
context_img_len=context_img_len,
kv_cache=kv_caches[i],
crossattn_cache=crossattn_caches[i])
# Head
x = self.head(x, e)
# Unpatchify
x = self.unpatchify(x, grid_sizes)
return x[:, :, :t, :h, :w]
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
"""Create fresh KV caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"end": 0,
})
return caches
def init_crossattn_caches(self, batch_size, device, dtype):
"""Create fresh cross-attention caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({"is_init": False})
return caches
def reset_kv_caches(self, kv_caches):
"""Reset KV caches to empty (reuse allocated memory)."""
for cache in kv_caches:
cache["end"] = 0
def reset_crossattn_caches(self, crossattn_caches):
"""Reset cross-attention caches."""
for cache in crossattn_caches:
cache["is_init"] = False
@property
def head_dim(self):
return self.dim // self.num_heads
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
ar_state = transformer_options.get("ar_state")
if ar_state is not None:
bs = x.shape[0]
block_frames = x.shape[2]
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
return self.forward_block(
x=x, timestep=t_per_frame, context=context,
start_frame=ar_state["start_frame"],
kv_caches=ar_state["kv_caches"],
crossattn_caches=ar_state["crossattn_caches"],
clip_fea=clip_fea,
)
return super().forward(x, timestep, context, clip_fea=clip_fea,
time_dim_concat=time_dim_concat,
transformer_options=transformer_options, **kwargs)

View File

@ -17,6 +17,7 @@
""" """
from __future__ import annotations from __future__ import annotations
import comfy.memory_management
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import comfy.model_base import comfy.model_base
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight weight = old_weight
return weight return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return value

View File

@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model import comfy.ldm.lumina.model
import comfy.ldm.wan.model import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.hunyuan3d.model import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model import comfy.ldm.hidream.model
import comfy.ldm.chroma.model import comfy.ldm.chroma.model
@ -214,6 +215,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds: if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
transformer_options = transformer_options.copy()
transformer_options["prefetch_dynamic_vbars"] = (
self.current_patcher is not None and self.current_patcher.is_dynamic()
)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output): if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output) model_output, _ = utils.pack_latents(model_output)
@ -1360,6 +1366,13 @@ class WAN21(BaseModel):
return out return out
class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
self.image_to_video = False
class WAN21_Vace(WAN21): class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management import comfy.memory_management
import comfy.utils import comfy.utils
import comfy.quant_ops import comfy.quant_ops
import comfy_aimdo.vram_buffer
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
@ -720,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else: else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models_temp = set() # Order-preserving dedup. A plain set() would randomize iteration order across runs
models_temp = {}
for m in models: for m in models:
models_temp.add(m) models_temp[m] = None
for mm in m.model_patches_models(): for mm in m.model_patches_models():
models_temp.add(mm) models_temp[mm] = None
models = models_temp models = list(models_temp)
models.reverse()
models_to_load = [] models_to_load = []
@ -1175,6 +1178,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {} STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref): def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT global LARGEST_CASTED_WEIGHT
@ -1208,13 +1215,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer return cast_buffer
def get_aimdo_cast_buffer(offload_stream, device):
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def reset_cast_buffers(): def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS: LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
offload_stream.synchronize() for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize() synchronize()
STREAM_CAST_BUFFERS.clear() STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache() soft_empty_cache()
def get_offload_stream(device): def get_offload_stream(device):

View File

@ -121,9 +121,20 @@ class LowVramPatch:
self.patches = patches self.patches = patches
self.convert_func = convert_func # TODO: remove self.convert_func = convert_func # TODO: remove
self.set_func = set_func self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
def clear_prepared(self):
self.prepared_patches = None
def __call__(self, weight): def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

66
comfy/model_prefetch.py Normal file
View File

@ -0,0 +1,66 @@
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
if offload_stream is not None:
offload_stream.wait_stream(comfy.model_management.current_stream(device))
_, comfy_modules = prefetch_state
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
prefetch = queue[0]
if prefetch is not None:
comfy_modules = []
for s in prefetch.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))
def make_prefetch_queue(queue, device, transformer_options):
if (not transformer_options.get("prefetch_dynamic_vbars", False)
or comfy.model_management.NUM_STREAMS == 0
or comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return None
queue = [None] + queue + [None]
PREFETCH_QUEUES.append(queue)
return queue

View File

@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): # FIXME: add n=1 cache hit fast path
#vbar doesn't support CPU weights, but some custom nodes have weird paths def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
offload_stream = None offload_stream = None
xfer_dest = None cast_buffer = None
cast_buffer_offset = 0
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
nonlocal cast_buffer
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
return
current_size = 0 if cast_buffer is None else cast_buffer.size()
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = None
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
def get_cast_buffer(buffer_size):
nonlocal offload_stream
nonlocal cast_buffer
nonlocal cast_buffer_offset
if buffer_size == 0:
return None
if offload_stream is None:
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
cast_buffer_offset += buffer_size
return buffer
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident: if resident:
weight = s._v_weight s._prefetch = prefetch
bias = s._v_bias continue
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
materialize_meta_param(s, ["weight", "bias"]) materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ] xfer_source = [ s.weight, s.bias ]
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None: if data is None:
continue continue
if data.dtype != geometry.dtype: if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None xfer_dest = None
break break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source) dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device) ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None: if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) xfer_dest = get_cast_buffer(dest_size)
offload_stream = None
if signature is None and pin is None: if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s) comfy.pinned_memory.pin_memory(s)
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ] xfer_source = [ pin ]
#send it over #send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None: for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
return offload_stream
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None: if post_cast is not None:
post_cast.copy_(pre_cast) post_cast.copy_(pre_cast)
xfer_dest = cast_dest xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0] weight = params[0]
bias = params[1] bias = params[1]
if signature is not None: if prefetch["signature"] is not None:
s._v_weight = weight s._v_weight = weight
s._v_bias = bias s._v_bias = bias
s._v_signature=signature s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight): def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", []) fns = getattr(s, param_key + "_function", [])
if x is None:
return None
orig = x orig = x
def to_dequant(tensor, dtype): def to_dequant(tensor, dtype):
@ -205,14 +248,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x) x = f(x)
return x return x
update_weight = signature is not None update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
if bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
weight = post_cast(s, "weight", weight, dtype, resident, update_weight) if prefetch["signature"] is not None:
if s.bias is not None: prefetch["resident"] = True
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
#FIXME: weird offload return protocol return weight, bias
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -230,10 +274,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None: if device is None:
device = input.device device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"): if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
if offloadable and (device != s.weight.device or if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)): (s.bias is not None and device != s.bias.device)):
@ -280,11 +360,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function: for f in s.weight_function:
weight = f(weight) weight = f(weight)
if offloadable: return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream): def uncast_bias_weight(s, weight, bias, offload_stream):
@ -1173,6 +1249,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
return self return self
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
if scale is not None:
scale = scale.float()
manually_loaded_keys.append(scale_key)
params = layout_cls.Params(
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.num_embeddings, self.embedding_dim),
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
# Optimized path: lookup in fp8, dequantize only the selected rows.
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
if isinstance(qdata, QuantizedTensor):
scale = qdata._params.scale
qdata = qdata._qdata
else:
scale = None
x = torch.nn.functional.embedding(
input, qdata, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
uncast_bias_weight(self, qdata, None, offload_stream)
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
x = x.to(dtype=target_dtype)
if scale is not None and scale != 1.0:
x = x * scale.to(dtype=target_dtype)
return x
# Fallback for non-quantized or weight_function (LoRA) case
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
return MixedPrecisionOps return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

View File

@ -1,6 +1,8 @@
import torch import torch
import logging import logging
from comfy.cli_args import args
try: try:
import comfy_kitchen as ck import comfy_kitchen as ck
from comfy_kitchen.tensor import ( from comfy_kitchen.tensor import (
@ -21,7 +23,15 @@ try:
ck.registry.disable("cuda") ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton") if args.enable_triton_backend:
try:
import triton
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
except ImportError as e:
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
ck.registry.disable("triton")
else:
ck.registry.disable("triton")
for k, v in ck.list_backends().items(): for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}") logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e: except ImportError as e:

View File

@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm RMSNorm = torch.nn.RMSNorm
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6): def rms_norm(x, weight=None, eps=1e-6):
if weight is None: if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

View File

@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
gligen += get_models_from_cond(conds[k], "gligen") gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models") add_models += get_models_from_cond(conds[k], "additional_models")
control_nets = set(cnets) # Order-preserving dedup. A plain set() would randomize iteration order across runs
control_nets = list(dict.fromkeys(cnets))
inference_memory = 0 inference_memory = 0
control_models = [] control_models = []

View File

@ -65,6 +65,8 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35 import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -1223,6 +1225,7 @@ class CLIPType(Enum):
NEWBIE = 24 NEWBIE = 24
FLUX2 = 25 FLUX2 = 25
LONGCAT_IMAGE = 26 LONGCAT_IMAGE = 26
COGVIDEOX = 27
@ -1271,6 +1274,9 @@ class TEModel(Enum):
QWEN35_9B = 26 QWEN35_9B = 26
QWEN35_27B = 27 QWEN35_27B = 27
MINISTRAL_3_3B = 28 MINISTRAL_3_3B = 28
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
def detect_te_model(sd): def detect_te_model(sd):
@ -1296,6 +1302,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd: if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd: if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1418,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
elif clip_type == CLIPType.COGVIDEOX:
clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer
else: #CLIPType.MOCHI else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@ -1435,6 +1450,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B: elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

View File

@ -1167,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
class WAN21_CausalAR_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"causal_ar": True,
}
sampling_settings = {
"shift": 5.0,
}
def __init__(self, unet_config):
super().__init__(unet_config)
self.unet_config.pop("causal_ar", None)
def get_model(self, state_dict, prefix="", device=None):
return model_base.WAN21_CausalAR(self, device=device)
class WAN21_I2V(WAN21_T2V): class WAN21_I2V(WAN21_T2V):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -1853,6 +1872,14 @@ class CogVideoX_T2V(supported_models_base.BASE):
vae_key_prefix = ["vae."] vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
# 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426.
# 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and
# Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json.
if unet_config.get("num_attention_heads", 0) >= 48:
self.latent_format = latent_formats.CogVideoX1_5
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
if self.unet_config.get("patch_size_t") is not None: if self.unet_config.get("patch_size_t") is not None:
@ -1879,6 +1906,20 @@ class CogVideoX_I2V(CogVideoX_T2V):
out = model_base.CogVideoX(self, image_to_video=True, device=device) out = model_base.CogVideoX(self, image_to_video=True, device=device)
return out return out
class CogVideoX_Inpaint(CogVideoX_T2V):
unet_config = {
"image_model": "cogvideox",
"in_channels": 48,
}
def get_model(self, state_dict, prefix="", device=None):
if self.unet_config.get("patch_size_t") is not None:
self.unet_config.setdefault("sample_height", 96)
self.unet_config.setdefault("sample_width", 170)
self.unet_config.setdefault("sample_frames", 81)
out = model_base.CogVideoX(self, image_to_video=True, device=device)
return out
models = [ models = [
LotusD, LotusD,
@ -1929,6 +1970,7 @@ models = [
ZImage, ZImage,
Lumina2, Lumina2,
WAN22_T2V, WAN22_T2V,
WAN21_CausalAR_T2V,
WAN21_T2V, WAN21_T2V,
WAN21_I2V, WAN21_I2V,
WAN21_FunControl2V, WAN21_FunControl2V,
@ -1958,6 +2000,7 @@ models = [
ErnieImage, ErnieImage,
SAM3, SAM3,
SAM31, SAM31,
CogVideoX_Inpaint,
CogVideoX_I2V, CogVideoX_I2V,
CogVideoX_T2V, CogVideoX_T2V,
SVD_img2vid, SVD_img2vid,

View File

@ -1,6 +1,48 @@
import comfy.text_encoders.sd3_clip import comfy.text_encoders.sd3_clip
from comfy import sd1_clip
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
"""Inner T5 tokenizer for CogVideoX.
CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3).
Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with
the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below.
"""
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
class CogVideoXTokenizer(sd1_clip.SD1Tokenizer):
"""Outer tokenizer wrapper for CLIPLoader (type="cogvideox")."""
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer)
class CogVideoXT5XXL(sd1_clip.SD1ClipModel):
"""Outer T5XXL model wrapper for CLIPLoader (type="cogvideox").
Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__
(which reads self.dtypes) works correctly. The inner model is the standard
sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX).
"""
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="t5xxl",
clip_model=comfy.text_encoders.sd3_clip.T5XXLModel,
model_options=model_options)
def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None):
"""Factory that returns a CogVideoXT5XXL class configured with the detected
T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts.
"""
class CogVideoXTEModel_(CogVideoXT5XXL):
def __init__(self, device="cpu", dtype=None, model_options={}):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return CogVideoXTEModel_

File diff suppressed because it is too large Load Diff

View File

@ -521,7 +521,7 @@ class Attention(nn.Module):
else: else:
present_key_value = (xk, xv, index + num_tokens) present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window: if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:] xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:] xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -533,12 +533,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value return self.o_proj(output), present_key_value
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__() super().__init__()
ops = ops or nn intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu": if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh": elif config.mlp_activation == "gelu_pytorch_tanh":
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value return x, present_key_value
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module): class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None): def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__() super().__init__()
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2 transformer = TransformerBlockGemma2
self.normalize_in = True self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else: else:
transformer = TransformerBlock transformer = TransformerBlock
self.normalize_in = False self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops) transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
self.config.rope_dims, self.config.rope_dims,
device=device) device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
if embeds is not None: if embeds is not None:
x = embeds x = embeds
else: else:
x = self.embed_tokens(x, out_dtype=dtype) x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1] seq_len = x.shape[1]
past_len = 0 past_len = 0
if past_key_values is not None and len(past_key_values) > 0: if past_key_values is not None and len(past_key_values) > 0:
@ -850,7 +848,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
device = embeds.device device = embeds.device
if stop_tokens is None: if stop_tokens is None:
@ -875,14 +873,16 @@ class BaseGenerate:
pbar = comfy.utils.ProgressBar(max_length) pbar = comfy.utils.ProgressBar(max_length)
# Generation loop # Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"): for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
logits = self.logits(x)[:, -1] logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item() token_id = next_token[0].item()
generated_token_ids.append(token_id) generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype) embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
pbar.update(1) pbar.update(1)
if token_id in stop_tokens: if token_id in stop_tokens:

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens] tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn> return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module): class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device): def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device) embeds, _, _, _ = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return embeds return embeds
class LuminaModel(sd1_clip.SD1ClipModel): class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops) Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res memo[obj_id] = res
return res return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -5,12 +5,95 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility. allowing graceful protocol evolution while maintaining backward compatibility.
""" """
from typing import Any import logging
from typing import Any, TypedDict
from comfy.cli_args import args from comfy.cli_args import args
class FeatureFlagInfo(TypedDict):
type: str
default: Any
description: str
# Registry of known CLI-settable feature flags.
# Launchers can query this via --list-feature-flags to discover valid flags.
CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
"show_signin_button": {
"type": "bool",
"default": False,
"description": "Show the sign-in button in the frontend even when not signed in",
},
}
def _coerce_bool(v: str) -> bool:
"""Strict bool coercion: only 'true'/'false' (case-insensitive).
Anything else raises ValueError so the caller can warn and drop the flag,
rather than silently treating typos like 'ture' or 'yes' as False.
"""
lower = v.lower()
if lower == "true":
return True
if lower == "false":
return False
raise ValueError(f"expected 'true' or 'false', got {v!r}")
_COERCE_FNS: dict[str, Any] = {
"bool": _coerce_bool,
"int": lambda v: int(v),
"float": lambda v: float(v),
}
def _coerce_flag_value(key: str, raw_value: str) -> Any:
"""Coerce a raw string value using the registry type, or keep as string.
Returns the raw string if the key is unregistered or the type is unknown.
Raises ValueError/TypeError if the key is registered with a known type but
the value cannot be coerced; callers are expected to warn and drop the flag.
"""
info = CLI_FEATURE_FLAG_REGISTRY.get(key)
if info is None:
return raw_value
coerce = _COERCE_FNS.get(info["type"])
if coerce is None:
return raw_value
return coerce(raw_value)
def _parse_cli_feature_flags() -> dict[str, Any]:
"""Parse --feature-flag key=value pairs from CLI args into a dict.
Items without '=' default to the value 'true' (bare flag form).
Flags whose value cannot be coerced to the registered type are dropped
with a warning, so a typo like '--feature-flag some_bool=ture' does not
silently take effect as the wrong value.
"""
result: dict[str, Any] = {}
for item in getattr(args, "feature_flag", []):
key, sep, raw_value = item.partition("=")
key = key.strip()
if not key:
continue
if not sep:
raw_value = "true"
try:
result[key] = _coerce_flag_value(key, raw_value.strip())
except (ValueError, TypeError) as e:
info = CLI_FEATURE_FLAG_REGISTRY.get(key, {})
logging.warning(
"Could not coerce --feature-flag %s=%r to %s (%s); dropping flag.",
key, raw_value.strip(), info.get("type", "?"), e,
)
return result
# Default server capabilities # Default server capabilities
SERVER_FEATURE_FLAGS: dict[str, Any] = { _CORE_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
@ -18,6 +101,11 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"assets": args.enable_assets, "assets": args.enable_assets,
} }
# CLI-provided flags cannot overwrite core flags
_cli_flags = {k: v for k, v in _parse_cli_feature_flags().items() if k not in _CORE_FEATURE_FLAGS}
SERVER_FEATURE_FLAGS: dict[str, Any] = {**_CORE_FEATURE_FLAGS, **_cli_flags}
def get_connection_feature( def get_connection_feature(
sockets_metadata: dict[str, dict[str, Any]], sockets_metadata: dict[str, dict[str, Any]],

View File

@ -395,7 +395,6 @@ class Combo(ComfyTypeIO):
@comfytype(io_type="COMBO") @comfytype(io_type="COMBO")
class MultiCombo(ComfyTypeI): class MultiCombo(ComfyTypeI):
'''Multiselect Combo input (dropdown for selecting potentially more than one value).''' '''Multiselect Combo input (dropdown for selecting potentially more than one value).'''
# TODO: something is wrong with the serialization, frontend does not recognize it as multiselect
Type = list[str] Type = list[str]
class Input(Combo.Input): class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
@ -408,12 +407,14 @@ class MultiCombo(ComfyTypeI):
self.default: list[str] self.default: list[str]
def as_dict(self): def as_dict(self):
to_return = super().as_dict() | prune_dict({ # Frontend expects `multi_select` to be an object config (not a boolean).
"multi_select": self.multiselect, # Keep top-level `multiselect` from Combo.Input for backwards compatibility.
"placeholder": self.placeholder, return super().as_dict() | prune_dict({
"chip": self.chip, "multi_select": prune_dict({
"placeholder": self.placeholder,
"chip": self.chip,
}),
}) })
return to_return
@comfytype(io_type="IMAGE") @comfytype(io_type="IMAGE")
class Image(ComfyTypeIO): class Image(ComfyTypeIO):

View File

@ -1,15 +1,12 @@
from __future__ import annotations from __future__ import annotations
import torch
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import Optional, Union
import torch
from pydantic import BaseModel, Field, confloat from pydantic import BaseModel, Field, confloat
class LumaIO: class LumaIO:
LUMA_REF = "LUMA_REF" LUMA_REF = "LUMA_REF"
LUMA_CONCEPTS = "LUMA_CONCEPTS" LUMA_CONCEPTS = "LUMA_CONCEPTS"
@ -183,13 +180,13 @@ class LumaAssets(BaseModel):
class LumaImageRef(BaseModel): class LumaImageRef(BaseModel):
'''Used for image gen''' """Used for image gen"""
url: str = Field(..., description='The URL of the image reference') url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
class LumaImageReference(BaseModel): class LumaImageReference(BaseModel):
'''Used for video gen''' """Used for video gen"""
type: Optional[str] = Field('image', description='Input type, defaults to image') type: Optional[str] = Field('image', description='Input type, defaults to image')
url: str = Field(..., description='The URL of the image') url: str = Field(..., description='The URL of the image')
@ -251,3 +248,32 @@ class LumaGeneration(BaseModel):
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
model: str = Field(..., description='The model used for the generation') model: str = Field(..., description='The model used for the generation')
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
class Luma2ImageRef(BaseModel):
url: str | None = None
data: str | None = None
media_type: str | None = None
class Luma2GenerationRequest(BaseModel):
prompt: str = Field(..., min_length=1, max_length=6000)
model: str | None = None
type: str | None = None
aspect_ratio: str | None = None
style: str | None = None
output_format: str | None = None
web_search: bool | None = None
image_ref: list[Luma2ImageRef] | None = None
source: Luma2ImageRef | None = None
class Luma2Generation(BaseModel):
id: str | None = None
type: str | None = None
state: str | None = None
model: str | None = None
created_at: str | None = None
output: list[LumaImageReference] | None = None
failure_reason: str | None = None
failure_code: str | None = None

View File

@ -56,14 +56,14 @@ class ModelResponseProperties(BaseModel):
instructions: str | None = Field(None) instructions: str | None = Field(None)
max_output_tokens: int | None = Field(None) max_output_tokens: int | None = Field(None)
model: str | None = Field(None) model: str | None = Field(None)
temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0) temperature: float | None = Field(None, description="Controls randomness in the response", ge=0.0, le=2.0)
top_p: float | None = Field( top_p: float | None = Field(
1, None,
description="Controls diversity of the response via nucleus sampling", description="Controls diversity of the response via nucleus sampling",
ge=0.0, ge=0.0,
le=1.0, le=1.0,
) )
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") truncation: str | None = Field(None, description="Allowed values: 'auto' or 'disabled'")
class ResponseProperties(BaseModel): class ResponseProperties(BaseModel):

View File

@ -1,4 +1,4 @@
from typing import Optional, Union from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
grain: Optional[float] = Field(None, description="Grain after AI model processing") grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain") grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
realism: float | None = Field(None, description="ast-2 realism control")
class OutputInformationVideo(BaseModel): class OutputInformationVideo(BaseModel):
@ -90,7 +93,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel): class CreateVideoRequest(BaseModel):
source: CreateVideoRequestSource = Field(...) source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
output: OutputInformationVideo = Field(...) output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@ -1,10 +1,11 @@
from typing import Optional
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.luma import ( from comfy_api_nodes.apis.luma import (
Luma2Generation,
Luma2GenerationRequest,
Luma2ImageRef,
LumaAspectRatio, LumaAspectRatio,
LumaCharacterRef, LumaCharacterRef,
LumaConceptChain, LumaConceptChain,
@ -30,6 +31,7 @@ from comfy_api_nodes.util import (
download_url_to_video_output, download_url_to_video_output,
poll_op, poll_op,
sync_op, sync_op,
upload_image_to_comfyapi,
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_string, validate_string,
) )
@ -212,9 +214,9 @@ class LumaImageGenerationNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
seed, seed,
style_image_weight: float, style_image_weight: float,
image_luma_ref: Optional[LumaReferenceChain] = None, image_luma_ref: LumaReferenceChain | None = None,
style_image: Optional[torch.Tensor] = None, style_image: torch.Tensor | None = None,
character_image: Optional[torch.Tensor] = None, character_image: torch.Tensor | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3) validate_string(prompt, strip_whitespace=True, min_length=3)
# handle image_luma_ref # handle image_luma_ref
@ -434,7 +436,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
duration: str, duration: str,
loop: bool, loop: bool,
seed, seed,
luma_concepts: Optional[LumaConceptChain] = None, luma_concepts: LumaConceptChain | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3) validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
@ -533,7 +535,6 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=PRICE_BADGE_VIDEO, price_badge=PRICE_BADGE_VIDEO,
) )
@classmethod @classmethod
@ -644,6 +645,293 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
) )
def _luma2_uni1_common_inputs(max_image_refs: int) -> list:
return [
IO.Combo.Input(
"style",
options=["auto", "manga"],
default="auto",
tooltip="Style preset. 'auto' picks based on the prompt; "
"'manga' applies a manga/anime aesthetic and requires a portrait "
"aspect ratio (2:3, 9:16, 1:2, 1:3).",
),
IO.Boolean.Input(
"web_search",
default=False,
tooltip="Search the web for visual references before generating.",
),
IO.Autogrow.Input(
"image_ref",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, max_image_refs + 1)],
min=0,
),
optional=True,
tooltip=f"Up to {max_image_refs} reference images for style/content guidance.",
),
]
async def _luma2_upload_image_refs(
cls: type[IO.ComfyNode],
refs: dict | None,
max_count: int,
) -> list[Luma2ImageRef] | None:
if not refs:
return None
out: list[Luma2ImageRef] = []
for key in refs:
url = await upload_image_to_comfyapi(cls, refs[key])
out.append(Luma2ImageRef(url=url))
if len(out) > max_count:
raise ValueError(f"Maximum {max_count} reference images are allowed.")
return out or None
async def _luma2_submit_and_poll(
cls: type[IO.ComfyNode],
request: Luma2GenerationRequest,
) -> Input.Image:
initial = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
response_model=Luma2Generation,
data=request,
)
if not initial.id:
raise RuntimeError("Luma 2 API did not return a generation id.")
final = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
response_model=Luma2Generation,
status_extractor=lambda r: r.state,
progress_extractor=lambda r: None,
)
if not final.output:
msg = final.failure_reason or "no output returned"
raise RuntimeError(f"Luma 2 generation failed: {msg}")
url = final.output[0].url
if not url:
raise RuntimeError("Luma 2 generation completed without an output URL.")
return await download_url_to_image_tensor(url)
class LumaImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageNode2",
display_name="Luma UNI-1 Image",
category="api node/image/Luma",
description="Generate images from text using the Luma UNI-1 model.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. 16000 characters.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"uni-1",
[
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"3:1",
"2:1",
"16:9",
"3:2",
"1:1",
"2:3",
"9:16",
"1:2",
"1:3",
],
default="auto",
tooltip="Output image aspect ratio. 'auto' lets "
"the model pick based on the prompt.",
),
*_luma2_uni1_common_inputs(max_image_refs=9),
],
),
IO.DynamicCombo.Option(
"uni-1-max",
[
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"3:1",
"2:1",
"16:9",
"3:2",
"1:1",
"2:3",
"9:16",
"1:2",
"1:3",
],
default="auto",
tooltip="Output image aspect ratio. 'auto' lets "
"the model pick based on the prompt.",
),
*_luma2_uni1_common_inputs(max_image_refs=9),
],
),
],
tooltip="Model to use for generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
expr="""
(
$m := widgets.model;
$refs := $lookup(inputGroups, "model.image_ref");
$base := $m = "uni-1-max" ? 0.1 : 0.0404;
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=6000)
aspect_ratio = model["aspect_ratio"]
style = model["style"]
allowed_manga_ratios = {"2:3", "9:16", "1:2", "1:3"}
if style == "manga" and aspect_ratio != "auto" and aspect_ratio not in allowed_manga_ratios:
raise ValueError(
f"'manga' style requires a portrait aspect ratio "
f"({', '.join(sorted(allowed_manga_ratios))}) or 'auto'; got '{aspect_ratio}'."
)
request = Luma2GenerationRequest(
prompt=prompt,
model=model["model"],
type="image",
aspect_ratio=aspect_ratio if aspect_ratio != "auto" else None,
style=style if style != "auto" else None,
output_format="png",
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
class LumaImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageEditNode2",
display_name="Luma UNI-1 Image Edit",
category="api node/image/Luma",
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
inputs=[
IO.Image.Input(
"source",
tooltip="Source image to edit.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Description of the desired edit. 16000 characters.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"uni-1",
_luma2_uni1_common_inputs(max_image_refs=8),
),
IO.DynamicCombo.Option(
"uni-1-max",
_luma2_uni1_common_inputs(max_image_refs=8),
),
],
tooltip="Model to use for editing.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
expr="""
(
$m := widgets.model;
$refs := $lookup(inputGroups, "model.image_ref");
$base := $m = "uni-1-max" ? 0.103 : 0.0434;
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
)
""",
),
)
@classmethod
async def execute(
cls,
source: Input.Image,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=6000)
request = Luma2GenerationRequest(
prompt=prompt,
model=model["model"],
type="image_edit",
source=Luma2ImageRef(url=await upload_image_to_comfyapi(cls, source)),
style=model["style"] if model["style"] != "auto" else None,
output_format="png",
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
class LumaExtension(ComfyExtension): class LumaExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -654,6 +942,8 @@ class LumaExtension(ComfyExtension):
LumaImageToVideoGenerationNode, LumaImageToVideoGenerationNode,
LumaReferenceNode, LumaReferenceNode,
LumaConceptsNode, LumaConceptsNode,
LumaImageNode,
LumaImageEditNode,
] ]

View File

@ -39,16 +39,18 @@ STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
class SupportedOpenAIModel(str, Enum): class SupportedOpenAIModel(str, Enum):
o4_mini = "o4-mini" gpt_5_5_pro = "gpt-5.5-pro"
o1 = "o1" gpt_5_5 = "gpt-5.5"
o3 = "o3"
o1_pro = "o1-pro"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
gpt_5 = "gpt-5" gpt_5 = "gpt-5"
gpt_5_mini = "gpt-5-mini" gpt_5_mini = "gpt-5-mini"
gpt_5_nano = "gpt-5-nano" gpt_5_nano = "gpt-5-nano"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
o4_mini = "o4-mini"
o3 = "o3"
o1_pro = "o1-pro"
o1 = "o1"
async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
@ -739,6 +741,16 @@ class OpenAIChatNode(IO.ComfyNode):
"usd": [0.002, 0.008], "usd": [0.002, 0.008],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
} }
: $contains($m, "gpt-5.5-pro") ? {
"type": "list_usd",
"usd": [0.03, 0.18],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5.5") ? {
"type": "list_usd",
"usd": [0.005, 0.03],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5-nano") ? { : $contains($m, "gpt-5-nano") ? {
"type": "list_usd", "type": "list_usd",
"usd": [0.00005, 0.0004], "usd": [0.00005, 0.0004],

View File

@ -33,7 +33,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="OpenAIVideoSora2", node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video (Deprecated)", display_name="OpenAI Sora - Video (DEPRECATED)",
category="api node/video/Sora", category="api node/video/Sora",
description=( description=(
"OpenAI video and audio generation.\n\n" "OpenAI video and audio generation.\n\n"

View File

@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
) )
UPSCALER_MODELS_MAP = { UPSCALER_MODELS_MAP = {
"Astra 2": "ast-2",
"Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1", "Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5", "Starlight Precise 2.5": "slp-2.5",
} }
AST2_MAX_FRAMES = 9000
AST2_MAX_FRAMES_WITH_PROMPT = 450
class TopazImageEnhance(IO.ComfyNode): class TopazImageEnhance(IO.ComfyNode):
@classmethod @classmethod
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="TopazVideoEnhance", node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance", display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz", category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.", description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[ inputs=[
IO.Video.Input("video"), IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True), IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), IO.Combo.Input(
"upscaler_model",
options=[
"Starlight (Astra) Fast",
"Starlight (Astra) Creative",
"Starlight Precise 2.5",
],
),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input( IO.Combo.Input(
"upscaler_creativity", "upscaler_creativity",
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazVideoEnhanceV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.DynamicCombo.Input(
"upscaler_model",
options=[
IO.DynamicCombo.Option(
"Astra 2",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Float.Input(
"creativity",
default=0.5,
min=0.0,
max=1.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Creative strength of the upscale.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional descriptive (not instructive) scene prompt."
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
),
IO.Float.Input(
"sharp",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pre-enhance sharpness: "
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
advanced=True,
),
IO.Float.Input(
"realism",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pulls output toward photographic realism."
"Leave at 0 for the model default.",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Fast",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Creative",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creative strength of the upscale.",
),
],
),
IO.DynamicCombo.Option(
"Starlight Precise 2.5",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
),
IO.DynamicCombo.Option("Disabled", []),
],
),
IO.DynamicCombo.Input(
"interpolation_model",
options=[
IO.DynamicCombo.Option("Disabled", []),
IO.DynamicCombo.Option(
"apo-8",
[
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
advanced=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
advanced=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
advanced=True,
),
],
),
],
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=[
"upscaler_model",
"upscaler_model.upscaler_resolution",
"interpolation_model",
]),
expr="""
(
$model := $lookup(widgets, "upscaler_model");
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
$interp := $lookup(widgets, "interpolation_model");
$is4k := $contains($res, "4k");
$hasInterp := $interp != "disabled";
$rates := {
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
"astra 2": {"hd": 1.72, "uhd": 2.85},
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
};
$surcharge := $is4k ? 0.28 : 0.14;
$entry := $lookup($rates, $model);
$base := $is4k ? $entry.uhd : $entry.hd;
$hi := $base + ($hasInterp ? $surcharge : 0);
$model = "disabled"
? {"type":"text","text":"Interpolation only"}
: ($hasInterp
? {"type":"text","text":"~" & $string($base) & "" & $string($hi) & " credits/src frame"}
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
upscaler_model: dict,
interpolation_model: dict,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
upscaler_choice = upscaler_model["upscaler_model"]
interpolation_choice = interpolation_model["interpolation_model"]
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_choice != "Disabled":
if "1080p" in upscaler_model["upscaler_resolution"]:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
if model_id == "slc-1":
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
isOptimizedMode=True,
)
)
elif model_id == "ast-2":
n_frames = video.get_frame_count()
ast2_prompt = (upscaler_model["prompt"] or "").strip()
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
raise ValueError(
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
)
if n_frames > AST2_MAX_FRAMES:
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
realism = upscaler_model["realism"]
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
prompt=(ast2_prompt or None),
sharp=upscaler_model["sharp"],
realism=(realism if realism > 0 else None),
)
)
else:
filters.append(VideoEnhancementFilter(model=model_id))
if interpolation_choice != "Disabled":
target_frame_rate = interpolation_model["interpolation_frame_rate"]
filters.append(
VideoFrameInterpolationFilter(
model=interpolation_choice,
slowmo=interpolation_model["interpolation_slowmo"],
fps=interpolation_model["interpolation_frame_rate"],
duplicate=interpolation_model["interpolation_duplicate"],
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=CreateVideoResponse,
data=CreateVideoRequest(
source=CreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=Resolution(width=src_width, height=src_height),
),
filters=filters,
output=OutputInformationVideo(
resolution=Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=VideoCompleteUploadResponse,
data=VideoCompleteUploadRequest(
uploadResults=[
VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazExtension(ComfyExtension): class TopazExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
TopazImageEnhance, TopazImageEnhance,
TopazVideoEnhance, TopazVideoEnhance,
TopazVideoEnhanceV2,
] ]

View File

@ -19,6 +19,8 @@ from comfy import utils
from comfy_api.latest import IO from comfy_api.latest import IO
from server import PromptServer from server import PromptServer
from comfy.deploy_environment import get_deploy_environment
from . import request_logger from . import request_logger
from ._helpers import ( from ._helpers import (
default_base_url, default_base_url,
@ -624,6 +626,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls)) payload_headers.update(get_auth_header(cfg.node_cls))
payload_headers["Comfy-Env"] = get_deploy_environment()
if cfg.endpoint.headers: if cfg.endpoint.headers:
payload_headers.update(cfg.endpoint.headers) payload_headers.update(cfg.endpoint.headers)

View File

@ -199,6 +199,9 @@ class FILMNet(nn.Module):
def get_dtype(self): def get_dtype(self):
return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype
def memory_used_forward(self, shape, dtype):
return 1700 * shape[1] * shape[2] * dtype.itemsize
def _build_warp_grids(self, H, W, device): def _build_warp_grids(self, H, W, device):
"""Pre-compute warp grids for all pyramid levels.""" """Pre-compute warp grids for all pyramid levels."""
if (H, W) in self._warp_grids: if (H, W) in self._warp_grids:

View File

@ -74,6 +74,9 @@ class IFNet(nn.Module):
def get_dtype(self): def get_dtype(self):
return self.encode.cnn0.weight.dtype return self.encode.cnn0.weight.dtype
def memory_used_forward(self, shape, dtype):
return 300 * shape[1] * shape[2] * dtype.itemsize
def _build_warp_grids(self, H, W, device): def _build_warp_grids(self, H, W, device):
if (H, W) in self._warp_grids: if (H, W) in self._warp_grids:
return return

View File

@ -42,7 +42,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
IO.Int.Input("bpm", default=120, min=10, max=300), IO.Int.Input("bpm", default=120, min=10, max=300),
IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1), IO.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
IO.Combo.Input("timesignature", options=['2', '3', '4', '6']), IO.Combo.Input("timesignature", options=['2', '3', '4', '6']),
IO.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]), IO.Combo.Input("language", options=['ar', 'az', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'es', 'fa', 'fi', 'fr', 'he', 'hi', 'hr', 'ht', 'hu', 'id', 'is', 'it', 'ja', 'ko', 'la', 'lt', 'ms', 'ne', 'nl', 'no', 'pa', 'pl', 'pt', 'ro', 'ru', 'sa', 'sk', 'sr', 'sv', 'sw', 'ta', 'te', 'th', 'tl', 'tr', 'uk', 'ur', 'vi', 'yue', 'zh', 'unknown'], default='en'),
IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]), IO.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True), IO.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True), IO.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),

View File

@ -0,0 +1,84 @@
"""
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
"""
import torch
from typing_extensions import override
import comfy.model_management
import comfy.samplers
from comfy_api.latest import ComfyExtension, io
class EmptyARVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyARVideoLatent",
category="latent/video",
inputs=[
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
io.Int.Input("length", default=81, min=1, max=1024, step=4),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Latent.Output(display_name="LATENT"),
],
)
@classmethod
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
lat_t = ((length - 1) // 4) + 1
latent = torch.zeros(
[batch_size, 16, lat_t, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples": latent})
class SamplerARVideo(io.ComfyNode):
"""Sampler for autoregressive video models (Causal Forcing, Self-Forcing).
All AR-loop parameters are owned by this node so they live in the workflow.
Add new widgets here as the AR sampler grows new options.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerARVideo",
display_name="Sampler AR Video",
category="sampling/custom_sampling/samplers",
inputs=[
io.Int.Input(
"num_frame_per_block",
default=1, min=1, max=64,
tooltip="Frames per autoregressive block. 1 = framewise, "
"3 = chunkwise. Must match the checkpoint's training mode.",
),
],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls, num_frame_per_block) -> io.NodeOutput:
extra_options = {
"num_frame_per_block": num_frame_per_block,
}
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
class ARVideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyARVideoLatent,
SamplerARVideo,
]
async def comfy_entrypoint() -> ARVideoExtension:
return ARVideoExtension()

View File

@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod @classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha)) batch_size = max(len(image), len(alpha))
out_images = []
alpha = 1.0 - resize_mask(alpha, image.shape[1:]) alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size): alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
return io.NodeOutput(torch.stack(out_images))
class CompositingExtension(ComfyExtension): class CompositingExtension(ComfyExtension):

View File

@ -29,6 +29,7 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
io.Boolean.Input("causal_window_fix", default=True, tooltip="Whether to add a causal fix frame to non-0-indexed context windows."),
], ],
outputs=[ outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."), io.Model.Output(tooltip="The model with context windows applied during sampling."),
@ -38,7 +39,7 @@ class ContextWindowsManualNode(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False, causal_window_fix: bool=True) -> io.Model:
model = model.clone() model = model.clone()
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
@ -50,7 +51,8 @@ class ContextWindowsManualNode(io.ComfyNode):
dim=dim, dim=dim,
freenoise=freenoise, freenoise=freenoise,
cond_retain_index_list=cond_retain_index_list, cond_retain_index_list=cond_retain_index_list,
split_conds_to_windows=split_conds_to_windows split_conds_to_windows=split_conds_to_windows,
causal_window_fix=causal_window_fix,
) )
# make memory usage calculation only take into account the context window latents # make memory usage calculation only take into account the context window latents
comfy.context_windows.create_prepare_sampling_wrapper(model) comfy.context_windows.create_prepare_sampling_wrapper(model)

View File

@ -37,7 +37,7 @@ class FrameInterpolationModelLoader(io.ComfyNode):
model = cls._detect_and_load(sd) model = cls._detect_and_load(sd)
dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32
model.eval().to(dtype) model.eval().to(dtype)
patcher = comfy.model_patcher.ModelPatcher( patcher = comfy.model_patcher.CoreModelPatcher(
model, model,
load_device=model_management.get_torch_device(), load_device=model_management.get_torch_device(),
offload_device=model_management.unet_offload_device(), offload_device=model_management.unet_offload_device(),
@ -78,7 +78,7 @@ class FrameInterpolate(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="FrameInterpolate", node_id="FrameInterpolate",
display_name="Frame Interpolate", display_name="Frame Interpolate",
category="image/video", category="video",
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
inputs=[ inputs=[
FrameInterpolationModel.Input("interp_model"), FrameInterpolationModel.Input("interp_model"),
@ -98,16 +98,13 @@ class FrameInterpolate(io.ComfyNode):
if num_frames < 2 or multiplier < 2: if num_frames < 2 or multiplier < 2:
return io.NodeOutput(images) return io.NodeOutput(images)
model_management.load_model_gpu(interp_model)
device = interp_model.load_device device = interp_model.load_device
dtype = interp_model.model_dtype() dtype = interp_model.model_dtype()
inference_model = interp_model.model inference_model = interp_model.model
activation_mem = inference_model.memory_used_forward(images.shape, dtype)
# Free VRAM for inference activations (model weights + ~20x a single frame's worth) model_management.load_models_gpu([interp_model], memory_required=activation_mem)
H, W = images.shape[1], images.shape[2]
activation_mem = H * W * 3 * images.element_size() * 20
model_management.free_memory(activation_mem, device)
align = getattr(inference_model, "pad_align", 1) align = getattr(inference_model, "pad_align", 1)
H, W = images.shape[1], images.shape[2]
# Prepare a single padded frame on device for determining output dimensions # Prepare a single padded frame on device for determining output dimensions
def prepare_frame(idx): def prepare_frame(idx):

View File

@ -11,7 +11,7 @@ class ImageCompare(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="ImageCompare", node_id="ImageCompare",
display_name="Image Compare", display_name="Compare Images",
description="Compares two images side by side with a slider.", description="Compares two images side by side with a slider.",
category="image", category="image",
essentials_category="Image Tools", essentials_category="Image Tools",

View File

@ -24,7 +24,7 @@ class ImageCrop(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ImageCrop", node_id="ImageCrop",
search_aliases=["trim"], search_aliases=["trim"],
display_name="Image Crop (Deprecated)", display_name="Crop Image (DEPRECATED)",
category="image/transform", category="image/transform",
is_deprecated=True, is_deprecated=True,
essentials_category="Image Tools", essentials_category="Image Tools",
@ -56,7 +56,7 @@ class ImageCropV2(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ImageCropV2", node_id="ImageCropV2",
search_aliases=["trim"], search_aliases=["trim"],
display_name="Image Crop", display_name="Crop Image",
category="image/transform", category="image/transform",
essentials_category="Image Tools", essentials_category="Image Tools",
has_intermediate_output=True, has_intermediate_output=True,
@ -109,6 +109,7 @@ class RepeatImageBatch(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="RepeatImageBatch", node_id="RepeatImageBatch",
search_aliases=["duplicate image", "clone image"], search_aliases=["duplicate image", "clone image"],
display_name="Repeat Image Batch",
category="image/batch", category="image/batch",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -131,6 +132,7 @@ class ImageFromBatch(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ImageFromBatch", node_id="ImageFromBatch",
search_aliases=["select image", "pick from batch", "extract image"], search_aliases=["select image", "pick from batch", "extract image"],
display_name="Get Image from Batch",
category="image/batch", category="image/batch",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -157,7 +159,8 @@ class ImageAddNoise(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ImageAddNoise", node_id="ImageAddNoise",
search_aliases=["film grain"], search_aliases=["film grain"],
category="image", display_name="Add Noise to Image",
category="image/postprocessing",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Int.Input( IO.Int.Input(
@ -259,7 +262,7 @@ class ImageStitch(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ImageStitch", node_id="ImageStitch",
search_aliases=["combine images", "join images", "concatenate images", "side by side"], search_aliases=["combine images", "join images", "concatenate images", "side by side"],
display_name="Image Stitch", display_name="Stitch Images",
description="Stitches image2 to image1 in the specified direction.\n" description="Stitches image2 to image1 in the specified direction.\n"
"If image2 is not provided, returns image1 unchanged.\n" "If image2 is not provided, returns image1 unchanged.\n"
"Optional spacing can be added between images.", "Optional spacing can be added between images.",
@ -434,6 +437,7 @@ class ResizeAndPadImage(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ResizeAndPadImage", node_id="ResizeAndPadImage",
search_aliases=["fit to size"], search_aliases=["fit to size"],
display_name="Resize And Pad Image",
category="image/transform", category="image/transform",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -485,6 +489,7 @@ class SaveSVGNode(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="SaveSVGNode", node_id="SaveSVGNode",
search_aliases=["export vector", "save vector graphics"], search_aliases=["export vector", "save vector graphics"],
display_name="Save SVG",
description="Save SVG files on disk.", description="Save SVG files on disk.",
category="image/save", category="image/save",
inputs=[ inputs=[
@ -591,7 +596,7 @@ class ImageRotate(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="ImageRotate", node_id="ImageRotate",
display_name="Image Rotate", display_name="Rotate Image",
search_aliases=["turn", "flip orientation"], search_aliases=["turn", "flip orientation"],
category="image/transform", category="image/transform",
essentials_category="Image Tools", essentials_category="Image Tools",
@ -624,6 +629,7 @@ class ImageFlip(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ImageFlip", node_id="ImageFlip",
search_aliases=["mirror", "reflect"], search_aliases=["mirror", "reflect"],
display_name="Flip Image",
category="image/transform", category="image/transform",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -650,6 +656,7 @@ class ImageScaleToMaxDimension(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="ImageScaleToMaxDimension", node_id="ImageScaleToMaxDimension",
display_name="Scale Image to Max Dimension",
category="image/upscaling", category="image/upscaling",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
@ -709,7 +716,7 @@ class SplitImageToTileList(IO.ComfyNode):
def get_grid_coords(width, height, tile_width, tile_height, overlap): def get_grid_coords(width, height, tile_width, tile_height, overlap):
coords = [] coords = []
stride_x = round(max(tile_width * 0.25, tile_width - overlap)) stride_x = round(max(tile_width * 0.25, tile_width - overlap))
stride_y = round(max(tile_width * 0.25, tile_height - overlap)) stride_y = round(max(tile_height * 0.25, tile_height - overlap))
y = 0 y = 0
while y < height: while y < height:

View File

@ -147,7 +147,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
z_channels = audio_vae.latent_channels z_channels = audio_vae.latent_channels
audio_freq = audio_vae.first_stage_model.latent_frequency_bins audio_freq = audio_vae.first_stage_model.latent_frequency_bins
sampling_rate = int(audio_vae.first_stage_model.sample_rate)
num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate)
@ -159,7 +158,6 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
return io.NodeOutput( return io.NodeOutput(
{ {
"samples": audio_latents, "samples": audio_latents,
"sample_rate": sampling_rate,
"type": "audio", "type": "audio",
} }
) )

View File

@ -80,7 +80,8 @@ class ImageCompositeMasked(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="ImageCompositeMasked", node_id="ImageCompositeMasked",
search_aliases=["paste image", "overlay", "layer"], search_aliases=["overlay", "layer", "paste image", "images composition"],
display_name="Image Composite Masked",
category="image", category="image",
inputs=[ inputs=[
IO.Image.Input("destination"), IO.Image.Input("destination"),
@ -201,6 +202,7 @@ class InvertMask(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="InvertMask", node_id="InvertMask",
search_aliases=["reverse mask", "flip mask"], search_aliases=["reverse mask", "flip mask"],
display_name="Invert Mask",
category="mask", category="mask",
inputs=[ inputs=[
IO.Mask.Input("mask"), IO.Mask.Input("mask"),
@ -222,6 +224,7 @@ class CropMask(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="CropMask", node_id="CropMask",
search_aliases=["cut mask", "extract mask region", "mask slice"], search_aliases=["cut mask", "extract mask region", "mask slice"],
display_name="Crop Mask",
category="mask", category="mask",
inputs=[ inputs=[
IO.Mask.Input("mask"), IO.Mask.Input("mask"),
@ -247,7 +250,8 @@ class MaskComposite(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="MaskComposite", node_id="MaskComposite",
search_aliases=["combine masks", "blend masks", "layer masks"], search_aliases=["combine masks", "blend masks", "layer masks", "masks composition"],
display_name="Combine Masks",
category="mask", category="mask",
inputs=[ inputs=[
IO.Mask.Input("destination"), IO.Mask.Input("destination"),
@ -298,6 +302,7 @@ class FeatherMask(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="FeatherMask", node_id="FeatherMask",
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"], search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
display_name="Feather Mask",
category="mask", category="mask",
inputs=[ inputs=[
IO.Mask.Input("mask"), IO.Mask.Input("mask"),

View File

@ -59,7 +59,8 @@ class ImageRGBToYUV(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ImageRGBToYUV", node_id="ImageRGBToYUV",
search_aliases=["color space conversion"], search_aliases=["color space conversion"],
category="image/batch", display_name="Image RGB to YUV",
category="image/color",
inputs=[ inputs=[
io.Image.Input("image"), io.Image.Input("image"),
], ],
@ -81,7 +82,8 @@ class ImageYUVToRGB(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ImageYUVToRGB", node_id="ImageYUVToRGB",
search_aliases=["color space conversion"], search_aliases=["color space conversion"],
category="image/batch", display_name="Image YUV to RGB",
category="image/color",
inputs=[ inputs=[
io.Image.Input("Y"), io.Image.Input("Y"),
io.Image.Input("U"), io.Image.Input("U"),

View File

@ -20,7 +20,8 @@ class Blend(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="ImageBlend", node_id="ImageBlend",
display_name="Image Blend", search_aliases=["mix images"],
display_name="Blend Images",
category="image/postprocessing", category="image/postprocessing",
essentials_category="Image Tools", essentials_category="Image Tools",
inputs=[ inputs=[
@ -224,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="ImageScaleToTotalPixels", node_id="ImageScaleToTotalPixels",
display_name="Scale Image to Total Pixels",
category="image/upscaling", category="image/upscaling",
inputs=[ inputs=[
io.Image.Input("image"), io.Image.Input("image"),
@ -568,7 +570,7 @@ class BatchImagesNode(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="BatchImagesNode", node_id="BatchImagesNode",
display_name="Batch Images", display_name="Batch Images",
category="image", category="image/batch",
essentials_category="Image Tools", essentials_category="Image Tools",
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
inputs=[ inputs=[
@ -666,12 +668,13 @@ class ColorTransfer(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="ColorTransfer", node_id="ColorTransfer",
display_name="Color Transfer",
category="image/postprocessing", category="image/postprocessing",
description="Match the colors of one image to another using various algorithms.", description="Match the colors of one image to another using various algorithms.",
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
inputs=[ inputs=[
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
io.DynamicCombo.Input("source_stats", io.DynamicCombo.Input("source_stats",
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",

View File

@ -9,7 +9,8 @@ class String(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="PrimitiveString", node_id="PrimitiveString",
display_name="String", search_aliases=["text", "string", "text box", "prompt"],
display_name="Text String",
category="utils/primitive", category="utils/primitive",
inputs=[ inputs=[
io.String.Input("value"), io.String.Input("value"),
@ -27,7 +28,8 @@ class StringMultiline(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="PrimitiveStringMultiline", node_id="PrimitiveStringMultiline",
display_name="String (Multiline)", search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"],
display_name="Text String (Multiline)",
category="utils/primitive", category="utils/primitive",
essentials_category="Basics", essentials_category="Basics",
inputs=[ inputs=[
@ -49,7 +51,7 @@ class Int(io.ComfyNode):
display_name="Int", display_name="Int",
category="utils/primitive", category="utils/primitive",
inputs=[ inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
], ],
outputs=[io.Int.Output()], outputs=[io.Int.Output()],
) )

View File

@ -10,9 +10,9 @@ class StringConcatenate(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringConcatenate", node_id="StringConcatenate",
display_name="Text Concatenate", search_aliases=["concatenate", "text concat", "join text", "merge text", "combine strings", "string concat", "append text", "combine text"],
category="utils/string", display_name="Concatenate Text",
search_aliases=["Concatenate", "text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], category="text",
inputs=[ inputs=[
io.String.Input("string_a", multiline=True), io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True), io.String.Input("string_b", multiline=True),
@ -33,9 +33,9 @@ class StringSubstring(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringSubstring", node_id="StringSubstring",
search_aliases=["Substring", "extract text", "text portion"], search_aliases=["substring", "extract text", "text portion"],
display_name="Text Substring", display_name="Substring",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
io.Int.Input("start"), io.Int.Input("start"),
@ -58,7 +58,7 @@ class StringLength(io.ComfyNode):
node_id="StringLength", node_id="StringLength",
search_aliases=["character count", "text size", "string length"], search_aliases=["character count", "text size", "string length"],
display_name="Text Length", display_name="Text Length",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
], ],
@ -77,9 +77,9 @@ class CaseConverter(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="CaseConverter", node_id="CaseConverter",
search_aliases=["Case Converter", "text case", "uppercase", "lowercase", "capitalize"], search_aliases=["case converter", "text case", "uppercase", "lowercase", "capitalize"],
display_name="Text Case Converter", display_name="Convert Text Case",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]), io.Combo.Input("mode", options=["UPPERCASE", "lowercase", "Capitalize", "Title Case"]),
@ -110,9 +110,9 @@ class StringTrim(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringTrim", node_id="StringTrim",
search_aliases=["Trim", "clean whitespace", "remove whitespace", "strip"], search_aliases=["trim", "clean whitespace", "remove whitespace", "remove spaces","strip"],
display_name="Text Trim", display_name="Trim Text",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
io.Combo.Input("mode", options=["Both", "Left", "Right"]), io.Combo.Input("mode", options=["Both", "Left", "Right"]),
@ -141,9 +141,9 @@ class StringReplace(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringReplace", node_id="StringReplace",
search_aliases=["Replace", "find and replace", "substitute", "swap text"], search_aliases=["replace", "find and replace", "substitute", "swap text"],
display_name="Text Replace", display_name="Replace Text",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
io.String.Input("find", multiline=True), io.String.Input("find", multiline=True),
@ -164,9 +164,9 @@ class StringContains(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringContains", node_id="StringContains",
search_aliases=["Contains", "text includes", "string includes"], search_aliases=["contains", "text includes", "string includes"],
display_name="Text Contains", display_name="Contains Text",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
io.String.Input("substring", multiline=True), io.String.Input("substring", multiline=True),
@ -192,9 +192,9 @@ class StringCompare(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringCompare", node_id="StringCompare",
search_aliases=["Compare", "text match", "string equals", "starts with", "ends with"], search_aliases=["compare", "text match", "string equals", "starts with", "ends with"],
display_name="Text Compare", display_name="Compare Text",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string_a", multiline=True), io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True), io.String.Input("string_b", multiline=True),
@ -228,9 +228,9 @@ class RegexMatch(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="RegexMatch", node_id="RegexMatch",
search_aliases=["Regex Match", "regex", "pattern match", "text contains", "string match"], search_aliases=["regex match", "regex", "pattern match", "text contains", "string match"],
display_name="Text Match", display_name="Match Text",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True), io.String.Input("regex_pattern", multiline=True),
@ -269,9 +269,9 @@ class RegexExtract(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="RegexExtract", node_id="RegexExtract",
search_aliases=["Regex Extract", "regex", "pattern extract", "text parser", "parse text"], search_aliases=["regex extract", "regex", "pattern extract", "text parser", "parse text"],
display_name="Text Extract Substring", display_name="Extract Text",
category="utils/string", category="text",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
io.String.Input("regex_pattern", multiline=True), io.String.Input("regex_pattern", multiline=True),
@ -344,9 +344,9 @@ class RegexReplace(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="RegexReplace", node_id="RegexReplace",
search_aliases=["Regex Replace", "regex", "pattern replace", "regex replace", "substitution"], search_aliases=["regex replace", "regex", "pattern replace", "substitution"],
display_name="Text Replace (Regex)", display_name="Replace Text (Regex)",
category="utils/string", category="text",
description="Find and replace text using regex patterns.", description="Find and replace text using regex patterns.",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -381,8 +381,8 @@ class JsonExtractString(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="JsonExtractString", node_id="JsonExtractString",
display_name="Extract String from JSON", display_name="Extract Text from JSON",
category="utils/string", category="text",
search_aliases=["json", "extract json", "parse json", "json value", "read json"], search_aliases=["json", "extract json", "parse json", "json value", "read json"],
inputs=[ inputs=[
io.String.Input("json_string", multiline=True), io.String.Input("json_string", multiline=True),

View File

@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
io.Clip.Input("clip"), io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True), io.Image.Input("image", optional=True),
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
io.Audio.Input("audio", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048), io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking) tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
# Get sampling parameters from dynamic combo # Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on" do_sample = sampling_mode.get("sampling_mode") == "on"
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
seed=seed seed=seed
) )
generated_text = clip.decode(generated_ids, skip_special_tokens=True) generated_text = clip.decode(generated_ids)
return io.NodeOutput(generated_text) return io.NodeOutput(generated_text)
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
) )
@classmethod @classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
if image is None: if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n" formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else: else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n" formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template) return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
class TextgenExtension(ComfyExtension): class TextgenExtension(ComfyExtension):

View File

@ -17,7 +17,8 @@ class SaveWEBM(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="SaveWEBM", node_id="SaveWEBM",
search_aliases=["export webm"], search_aliases=["export webm"],
category="image/video", display_name="Save WEBM",
category="video",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Image.Input("images"), io.Image.Input("images"),
@ -72,7 +73,7 @@ class SaveVideo(io.ComfyNode):
node_id="SaveVideo", node_id="SaveVideo",
search_aliases=["export video"], search_aliases=["export video"],
display_name="Save Video", display_name="Save Video",
category="image/video", category="video",
essentials_category="Basics", essentials_category="Basics",
description="Saves the input images to your ComfyUI output directory.", description="Saves the input images to your ComfyUI output directory.",
inputs=[ inputs=[
@ -121,7 +122,7 @@ class CreateVideo(io.ComfyNode):
node_id="CreateVideo", node_id="CreateVideo",
search_aliases=["images to video"], search_aliases=["images to video"],
display_name="Create Video", display_name="Create Video",
category="image/video", category="video",
description="Create a video from images.", description="Create a video from images.",
inputs=[ inputs=[
io.Image.Input("images", tooltip="The images to create a video from."), io.Image.Input("images", tooltip="The images to create a video from."),
@ -146,7 +147,7 @@ class GetVideoComponents(io.ComfyNode):
node_id="GetVideoComponents", node_id="GetVideoComponents",
search_aliases=["extract frames", "split video", "video to images", "demux"], search_aliases=["extract frames", "split video", "video to images", "demux"],
display_name="Get Video Components", display_name="Get Video Components",
category="image/video", category="video",
description="Extracts all components from a video: frames, audio, and framerate.", description="Extracts all components from a video: frames, audio, and framerate.",
inputs=[ inputs=[
io.Video.Input("video", tooltip="The video to extract components from."), io.Video.Input("video", tooltip="The video to extract components from."),
@ -174,7 +175,7 @@ class LoadVideo(io.ComfyNode):
node_id="LoadVideo", node_id="LoadVideo",
search_aliases=["import video", "open video", "video file"], search_aliases=["import video", "open video", "video file"],
display_name="Load Video", display_name="Load Video",
category="image/video", category="video",
essentials_category="Basics", essentials_category="Basics",
inputs=[ inputs=[
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video), io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
@ -216,7 +217,7 @@ class VideoSlice(io.ComfyNode):
"frame load cap", "frame load cap",
"start time", "start time",
], ],
category="image/video", category="video",
essentials_category="Video Tools", essentials_category="Video Tools",
inputs=[ inputs=[
io.Video.Input("video"), io.Video.Input("video"),

483
comfy_extras/nodes_void.py Normal file
View File

@ -0,0 +1,483 @@
import logging
import torch
import comfy
import comfy.model_management
import comfy.model_patcher
import comfy.samplers
import comfy.utils
import folder_paths
import node_helpers
import nodes
from comfy.utils import model_trange as trange
from comfy_api.latest import ComfyExtension, io
from torchvision.models.optical_flow import raft_large
from typing_extensions import override
from comfy_extras.void_noise_warp import RaftOpticalFlow, get_noise_from_video
OpticalFlow = io.Custom("OPTICAL_FLOW")
TEMPORAL_COMPRESSION = 4
PATCH_SIZE_T = 2
def _valid_void_length(length: int) -> int:
"""Round ``length`` down to a value that produces an even latent_t.
VOID / CogVideoX-Fun-V1.5 uses patch_size_t=2, so the VAE-encoded latent
must have an even temporal dimension. If latent_t is odd, the transformer
pad_to_patch_size circular-wraps an extra latent frame onto the end; after
the post-transformer crop the last real latent frame has been influenced
by the wrapped phantom frame, producing visible jitter and "disappearing"
subjects near the end of the decoded video. Rounding down fixes this.
"""
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
if latent_t % PATCH_SIZE_T == 0:
return length
# Round latent_t down to the nearest multiple of PATCH_SIZE_T, then invert
# the ((length - 1) // TEMPORAL_COMPRESSION) + 1 formula. Floor at 1 frame
# so we never return a non-positive length.
target_latent_t = max(PATCH_SIZE_T, (latent_t // PATCH_SIZE_T) * PATCH_SIZE_T)
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
class OpticalFlowLoader(io.ComfyNode):
"""Load an optical flow model from ``models/optical_flow/``.
Only torchvision's RAFT-large format is recognized today (the model used
by VOIDWarpedNoise). The checkpoint must be placed under
``models/optical_flow/`` ComfyUI never downloads optical-flow weights
at runtime.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="OpticalFlowLoader",
display_name="Load Optical Flow Model",
category="loaders",
inputs=[
io.Combo.Input(
"model_name",
options=folder_paths.get_filename_list("optical_flow"),
tooltip=(
"Optical flow model to load. Files must be placed in the "
"'optical_flow' folder. Today only torchvision's "
"raft_large.pth is supported."
),
),
],
outputs=[
OpticalFlow.Output(),
],
)
@classmethod
def execute(cls, model_name) -> io.NodeOutput:
model_path = folder_paths.get_full_path_or_raise("optical_flow", model_name)
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
has_raft_keys = (
any(k.startswith("feature_encoder.") for k in sd)
and any(k.startswith("context_encoder.") for k in sd)
and any(k.startswith("update_block.") for k in sd)
)
if not has_raft_keys:
raise ValueError(
"Unrecognized optical flow model format: expected a torchvision "
"RAFT-large state dict with 'feature_encoder.', 'context_encoder.' "
"and 'update_block.' prefixes."
)
model = raft_large(weights=None, progress=False)
model.load_state_dict(sd)
model.eval().to(torch.float32)
patcher = comfy.model_patcher.ModelPatcher(
model,
load_device=comfy.model_management.get_torch_device(),
offload_device=comfy.model_management.unet_offload_device(),
)
return io.NodeOutput(patcher)
class VOIDQuadmaskPreprocess(io.ComfyNode):
"""Preprocess a quadmask video for VOID inpainting.
Quantizes mask values to four semantic levels, inverts, and normalizes:
0 -> primary object to remove
63 -> overlap of primary + affected
127 -> affected region (interactions)
255 -> background (keep)
After inversion and normalization, the output mask has values in [0, 1]
with four discrete levels: 1.0 (remove), ~0.75, ~0.50, 0.0 (keep).
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VOIDQuadmaskPreprocess",
category="mask/video",
inputs=[
io.Mask.Input("mask"),
io.Int.Input("dilate_width", default=0, min=0, max=50, step=1,
tooltip="Dilation radius for the primary mask region (0 = no dilation)"),
],
outputs=[
io.Mask.Output(display_name="quadmask"),
],
)
@classmethod
def execute(cls, mask, dilate_width=0) -> io.NodeOutput:
m = mask.clone()
if m.max() <= 1.0:
m = m * 255.0
if dilate_width > 0 and m.ndim >= 3:
binary = (m < 128).float()
kernel_size = dilate_width * 2 + 1
if binary.ndim == 3:
binary = binary.unsqueeze(1)
dilated = torch.nn.functional.max_pool2d(
binary, kernel_size=kernel_size, stride=1, padding=dilate_width
)
if dilated.ndim == 4:
dilated = dilated.squeeze(1)
m = torch.where(dilated > 0.5, torch.zeros_like(m), m)
m = torch.where(m <= 31, torch.zeros_like(m), m)
m = torch.where((m > 31) & (m <= 95), torch.full_like(m, 63), m)
m = torch.where((m > 95) & (m <= 191), torch.full_like(m, 127), m)
m = torch.where(m > 191, torch.full_like(m, 255), m)
m = (255.0 - m) / 255.0
return io.NodeOutput(m)
class VOIDInpaintConditioning(io.ComfyNode):
"""Build VOID inpainting conditioning for CogVideoX.
Encodes the processed quadmask and masked source video through the VAE,
producing a 32-channel concat conditioning (16ch mask + 16ch masked video)
that gets concatenated with the 16ch noise latent by the model.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VOIDInpaintConditioning",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Image.Input("video", tooltip="Source video frames [T, H, W, 3]"),
io.Mask.Input("quadmask", tooltip="Preprocessed quadmask from VOIDQuadmaskPreprocess [T, H, W]"),
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
tooltip="Number of pixel frames to process. For CogVideoX-Fun-V1.5 "
"(patch_size_t=2), latent_t must be even — lengths that "
"produce odd latent_t are rounded down (e.g. 49 → 45)."),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, video, quadmask,
width, height, length, batch_size) -> io.NodeOutput:
adjusted_length = _valid_void_length(length)
if adjusted_length != length:
logging.warning(
"VOIDInpaintConditioning: rounding length %d down to %d so that "
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2). "
"Using odd latent_t causes the last frame to be corrupted by "
"circular padding.", length, adjusted_length,
)
length = adjusted_length
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
latent_h = height // 8
latent_w = width // 8
vid = video[:length]
vid = comfy.utils.common_upscale(
vid.movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
qm = quadmask[:length]
if qm.ndim == 3:
qm = qm.unsqueeze(-1)
qm = comfy.utils.common_upscale(
qm.movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
if qm.ndim == 4 and qm.shape[-1] == 1:
qm = qm.squeeze(-1)
mask_condition = qm
if mask_condition.ndim == 3:
mask_condition_3ch = mask_condition.unsqueeze(-1).expand(-1, -1, -1, 3)
else:
mask_condition_3ch = mask_condition
inverted_mask_3ch = 1.0 - mask_condition_3ch
masked_video = vid[:, :, :, :3] * (1.0 - mask_condition_3ch)
mask_latents = vae.encode(inverted_mask_3ch)
masked_video_latents = vae.encode(masked_video)
def _match_temporal(lat, target_t):
if lat.shape[2] > target_t:
return lat[:, :, :target_t]
elif lat.shape[2] < target_t:
pad = target_t - lat.shape[2]
return torch.cat([lat, lat[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2)
return lat
mask_latents = _match_temporal(mask_latents, latent_t)
masked_video_latents = _match_temporal(masked_video_latents, latent_t)
inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1)
# No explicit scaling needed here: the model's CogVideoX.concat_cond()
# applies process_latent_in (×latent_format.scale_factor) to each 16-ch
# block of the stored conditioning. For 5b-class checkpoints (incl. the
# VOID/CogVideoX-Fun-V1.5 inpainting model) that scale_factor is auto-
# selected as 0.7 in supported_models.CogVideoX_T2V, which matches the
# diffusers vae/config.json scaling_factor VOID was trained with.
positive = node_helpers.conditioning_set_values(
positive, {"concat_latent_image": inpaint_latents}
)
negative = node_helpers.conditioning_set_values(
negative, {"concat_latent_image": inpaint_latents}
)
noise_latent = torch.zeros(
[batch_size, 16, latent_t, latent_h, latent_w],
device=comfy.model_management.intermediate_device()
)
return io.NodeOutput(positive, negative, {"samples": noise_latent})
class VOIDWarpedNoise(io.ComfyNode):
"""Generate optical-flow warped noise for VOID Pass 2 refinement.
Takes the Pass 1 output video and produces temporally-correlated noise
by warping Gaussian noise along optical flow vectors. This noise is used
as the initial latent for Pass 2, resulting in better temporal consistency.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VOIDWarpedNoise",
category="latent/video",
inputs=[
OpticalFlow.Input(
"optical_flow",
tooltip="Optical flow model from OpticalFlowLoader (RAFT-large).",
),
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
tooltip="Number of pixel frames. Rounded down to make latent_t "
"even (patch_size_t=2 requirement), e.g. 49 → 45."),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Latent.Output(display_name="warped_noise"),
],
)
@classmethod
def execute(cls, optical_flow, video, width, height, length, batch_size) -> io.NodeOutput:
adjusted_length = _valid_void_length(length)
if adjusted_length != length:
logging.warning(
"VOIDWarpedNoise: rounding length %d down to %d so that "
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2).",
length, adjusted_length,
)
length = adjusted_length
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
latent_h = height // 8
latent_w = width // 8
# RAFT + noise warp is real compute, not an "intermediate" buffer, so
# we want the actual torch device (CUDA/MPS). The final latent is
# moved back to intermediate_device() before returning to match the
# rest of the ComfyUI pipeline.
device = comfy.model_management.get_torch_device()
comfy.model_management.load_model_gpu(optical_flow)
raft = RaftOpticalFlow(optical_flow.model, device=device)
vid = video[:length].to(device)
vid = comfy.utils.common_upscale(
vid.movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
vid_uint8 = (vid.clamp(0, 1) * 255).to(torch.uint8)
FRAME = 2**-1
FLOW = 2**3
LATENT_SCALE = 8
warped = get_noise_from_video(
vid_uint8,
raft,
noise_channels=16,
resize_frames=FRAME,
resize_flow=FLOW,
downscale_factor=round(FRAME * FLOW) * LATENT_SCALE,
device=device,
)
if warped.shape[0] != latent_t:
indices = torch.linspace(0, warped.shape[0] - 1, latent_t,
device=device).long()
warped = warped[indices]
if warped.shape[1] != latent_h or warped.shape[2] != latent_w:
# (T, H, W, C) → (T, C, H, W) → bilinear resize → back
warped = warped.permute(0, 3, 1, 2)
warped = torch.nn.functional.interpolate(
warped, size=(latent_h, latent_w),
mode="bilinear", align_corners=False,
)
warped = warped.permute(0, 2, 3, 1)
# (T, H, W, C) → (B, C, T, H, W)
warped_tensor = warped.permute(3, 0, 1, 2).unsqueeze(0)
if batch_size > 1:
warped_tensor = warped_tensor.repeat(batch_size, 1, 1, 1, 1)
warped_tensor = warped_tensor.to(comfy.model_management.intermediate_device())
return io.NodeOutput({"samples": warped_tensor})
class Noise_FromLatent:
"""Wraps a pre-computed LATENT tensor as a NOISE source."""
def __init__(self, latent_dict):
self.seed = 0
self._samples = latent_dict["samples"]
def generate_noise(self, input_latent):
return self._samples.clone().cpu()
class VOIDWarpedNoiseSource(io.ComfyNode):
"""Convert a LATENT (e.g. from VOIDWarpedNoise) into a NOISE source
for use with SamplerCustomAdvanced."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VOIDWarpedNoiseSource",
category="sampling/custom_sampling/noise",
inputs=[
io.Latent.Input("warped_noise",
tooltip="Warped noise latent from VOIDWarpedNoise"),
],
outputs=[io.Noise.Output()],
)
@classmethod
def execute(cls, warped_noise) -> io.NodeOutput:
return io.NodeOutput(Noise_FromLatent(warped_noise))
class VOID_DDIM(comfy.samplers.Sampler):
"""DDIM sampler for VOID inpainting models.
VOID was trained with the diffusers CogVideoXDDIMScheduler which operates in
alpha-space (input std 1). The standard KSampler applies noise_scaling that
multiplies by sqrt(1+sigma^2) 4500x, which is incompatible with VOID's
training. This sampler skips noise_scaling and implements the DDIM update rule
directly using sigma-to-alpha conversion.
"""
def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
x = noise.to(torch.float32)
model_options = extra_args.get("model_options", {})
seed = extra_args.get("seed", None)
s_in = x.new_ones([x.shape[0]])
for i in trange(len(sigmas) - 1, disable=disable_pbar):
sigma = sigmas[i]
sigma_next = sigmas[i + 1]
denoised = model_wrap(x, sigma * s_in, model_options=model_options, seed=seed)
if callback is not None:
callback(i, denoised, x, len(sigmas) - 1)
if sigma_next == 0:
x = denoised
else:
alpha_t = 1.0 / (1.0 + sigma ** 2)
alpha_prev = 1.0 / (1.0 + sigma_next ** 2)
pred_eps = (x - (alpha_t ** 0.5) * denoised) / (1.0 - alpha_t) ** 0.5
x = (alpha_prev ** 0.5) * denoised + (1.0 - alpha_prev) ** 0.5 * pred_eps
return x
class VOIDSampler(io.ComfyNode):
"""VOID DDIM sampler for use with SamplerCustom / SamplerCustomAdvanced.
Required for VOID inpainting models. Implements the same DDIM loop that VOID
was trained with (diffusers CogVideoXDDIMScheduler), without the noise_scaling
that the standard KSampler applies. Use with RandomNoise or VOIDWarpedNoiseSource.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="VOIDSampler",
category="sampling/custom_sampling/samplers",
inputs=[],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls) -> io.NodeOutput:
return io.NodeOutput(VOID_DDIM())
get_sampler = execute
class VOIDExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
OpticalFlowLoader,
VOIDQuadmaskPreprocess,
VOIDInpaintConditioning,
VOIDWarpedNoise,
VOIDWarpedNoiseSource,
VOIDSampler,
]
async def comfy_entrypoint() -> VOIDExtension:
return VOIDExtension()

View File

@ -0,0 +1,494 @@
"""
Optical-flow-warped noise for VOID Pass 2 refinement.
Adapted from RyannDaGreat/CommonSource (MIT License, Ryan Burgert):
https://github.com/RyannDaGreat/CommonSource
- noise_warp.py (NoiseWarper / warp_xyωc / regaussianize / get_noise_from_video)
- raft.py (RaftOpticalFlow)
Only the code paths that ``comfy_extras/nodes_void.py::VOIDWarpedNoise`` actually
uses (torch THWC uint8 input, no background removal, no visualization, no disk
I/O, default warp/noise params) have been inlined. External ``rp`` utilities
have been replaced with equivalents from torch.nn.functional / einops. The
RAFT optical-flow model itself is loaded offline via ``OpticalFlowLoader`` in
``nodes_void.py`` and passed into ``get_noise_from_video`` by the caller; this
module never downloads weights at runtime.
"""
import logging
from typing import Optional
import torch
import torch.nn.functional as F
from einops import rearrange
import comfy.model_management
# ---------------------------------------------------------------------------
# Low-level torch image helpers (drop-in replacements for rp.torch_* primitives)
# ---------------------------------------------------------------------------
def _torch_resize_chw(image, size, interp, copy=True):
"""Resize a CHW tensor.
``size`` is either a scalar factor or a (h, w) tuple. ``interp`` is one
of ``"bilinear"``, ``"nearest"``, ``"area"``. When ``copy`` is False and
the requested size matches the input, returns the input tensor as is
(faster but callers must not mutate the result).
"""
if image.ndim != 3:
raise ValueError(
f"_torch_resize_chw expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
_, in_h, in_w = image.shape
if isinstance(size, (int, float)) and not isinstance(size, bool):
new_h = max(1, int(in_h * size))
new_w = max(1, int(in_w * size))
else:
new_h, new_w = size
if (new_h, new_w) == (in_h, in_w):
return image.clone() if copy else image
kwargs = {}
if interp in ("bilinear", "bicubic"):
kwargs["align_corners"] = False
out = F.interpolate(image[None], size=(new_h, new_w), mode=interp, **kwargs)[0]
return out
def _torch_remap_relative(image, dx, dy, interp="bilinear"):
"""Relative remap of a CHW image via ``F.grid_sample``.
Equivalent to ``rp.torch_remap_image(image, dx, dy, relative=True, interp=interp)``
for ``interp`` in {"bilinear", "nearest"}. Out-of-bounds samples are 0.
"""
if image.ndim != 3:
raise ValueError(
f"_torch_remap_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
if dx.shape != dy.shape:
raise ValueError(
f"_torch_remap_relative: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
)
_, h, w = image.shape
x_abs = dx + torch.arange(w, device=dx.device, dtype=dx.dtype)
y_abs = dy + torch.arange(h, device=dy.device, dtype=dy.dtype)[:, None]
x_norm = (x_abs / (w - 1)) * 2 - 1
y_norm = (y_abs / (h - 1)) * 2 - 1
grid = torch.stack([x_norm, y_norm], dim=-1)[None].to(image.dtype)
out = F.grid_sample(
image[None], grid, mode=interp, align_corners=True, padding_mode="zeros"
)[0]
return out
def _torch_scatter_add_relative(image, dx, dy):
"""Scatter-add a CHW image using relative floor-rounded (dx, dy) offsets.
Equivalent to ``rp.torch_scatter_add_image(image, dx, dy, relative=True,
interp='floor')``. Out-of-bounds targets are dropped.
"""
if image.ndim != 3:
raise ValueError(
f"_torch_scatter_add_relative expects a 3D CHW tensor, got shape {tuple(image.shape)}"
)
in_c, in_h, in_w = image.shape
if dx.shape != (in_h, in_w) or dy.shape != (in_h, in_w):
raise ValueError(
f"_torch_scatter_add_relative: dx/dy must be ({in_h}, {in_w}), "
f"got dx={tuple(dx.shape)} dy={tuple(dy.shape)}"
)
x = dx.long() + torch.arange(in_w, device=dx.device, dtype=torch.long)
y = dy.long() + torch.arange(in_h, device=dy.device, dtype=torch.long)[:, None]
valid = ((y >= 0) & (y < in_h) & (x >= 0) & (x < in_w)).reshape(-1)
indices = (y * in_w + x).reshape(-1)[valid]
flat_image = rearrange(image, "c h w -> (h w) c")[valid]
out = torch.zeros((in_h * in_w, in_c), dtype=image.dtype, device=image.device)
out.index_add_(0, indices, flat_image)
return rearrange(out, "(h w) c -> c h w", h=in_h, w=in_w)
# ---------------------------------------------------------------------------
# Noise warping primitives (ported from noise_warp.py)
# ---------------------------------------------------------------------------
def unique_pixels(image):
"""Find unique pixel values in a CHW tensor.
Returns ``(unique_colors [U, C], counts [U], index_matrix [H, W])`` where
``index_matrix[i, j]`` is the index of the unique color at that pixel.
"""
_, h, w = image.shape
flat = rearrange(image, "c h w -> (h w) c")
unique_colors, inverse_indices, counts = torch.unique(
flat, dim=0, return_inverse=True, return_counts=True, sorted=False,
)
index_matrix = rearrange(inverse_indices, "(h w) -> h w", h=h, w=w)
return unique_colors, counts, index_matrix
def sum_indexed_values(image, index_matrix):
"""For each unique index, sum the CHW image values at its pixels."""
_, h, w = image.shape
u = int(index_matrix.max().item()) + 1
flat = rearrange(image, "c h w -> (h w) c")
out = torch.zeros((u, flat.shape[1]), dtype=flat.dtype, device=flat.device)
out.index_add_(0, index_matrix.view(-1), flat)
return out
def indexed_to_image(index_matrix, unique_colors):
"""Build a CHW image from an index matrix and a (U, C) color table."""
h, w = index_matrix.shape
flat = unique_colors[index_matrix.view(-1)]
return rearrange(flat, "(h w) c -> c h w", h=h, w=w)
def regaussianize(noise):
"""Variance-preserving re-sampling of a CHW noise tensor.
Wherever the noise contains groups of identical pixel values (e.g. after
a nearest-neighbor warp that duplicated source pixels), adds zero-mean
foreign noise within each group and scales by ``1/sqrt(count)`` so the
output is unit-variance gaussian again.
"""
_, hs, ws = noise.shape
_, counts, index_matrix = unique_pixels(noise[:1])
foreign_noise = torch.randn_like(noise)
summed = sum_indexed_values(foreign_noise, index_matrix)
meaned = indexed_to_image(index_matrix, summed / rearrange(counts, "u -> u 1"))
zeroed_foreign = foreign_noise - meaned
counts_image = indexed_to_image(index_matrix, rearrange(counts, "u -> u 1"))
output = noise / counts_image ** 0.5 + zeroed_foreign
return output, counts_image
def xy_meshgrid_like_image(image):
"""Return a (2, H, W) tensor of (x, y) pixel coordinates matching ``image``."""
_, h, w = image.shape
y, x = torch.meshgrid(
torch.arange(h, device=image.device, dtype=image.dtype),
torch.arange(w, device=image.device, dtype=image.dtype),
indexing="ij",
)
return torch.stack([x, y])
def noise_to_state(noise):
"""Pack a (C, H, W) noise tensor into a state tensor (3+C, H, W) = [dx, dy, ω, noise]."""
zeros = torch.zeros_like(noise[:1])
ones = torch.ones_like(noise[:1])
return torch.cat([zeros, zeros, ones, noise])
def state_to_noise(state):
"""Unpack the noise channels from a state tensor."""
return state[3:]
def warp_state(state, flow):
"""Warp a noise-warper state tensor along the given optical flow.
``state`` has shape ``(3+c, h, w)`` (= dx, dy, ω, c noise channels).
``flow`` has shape ``(2, h, w)`` (= dx, dy).
"""
if flow.device != state.device:
raise ValueError(
f"warp_state: flow and state must be on the same device, "
f"got flow={flow.device} state={state.device}"
)
if state.ndim != 3:
raise ValueError(
f"warp_state: state must be 3D (3+C, H, W), got shape {tuple(state.shape)}"
)
xyoc, h, w = state.shape
if flow.shape != (2, h, w):
raise ValueError(
f"warp_state: flow must have shape (2, {h}, {w}), got {tuple(flow.shape)}"
)
device = state.device
x_ch, y_ch = 0, 1
xy = 2 # state[:xy] = [dx, dy]
xyw = 3 # state[:xyw] = [dx, dy, ω]
w_ch = 2 # state[w_ch] = ω
c = xyoc - xyw
oc = xyoc - xy
if c <= 0:
raise ValueError(
f"warp_state: state has no noise channels (expected 3+C with C>0, got {xyoc} channels)"
)
if not (state[w_ch] > 0).all():
raise ValueError("warp_state: all weights in state[2] must be > 0")
grid = xy_meshgrid_like_image(state)
init = torch.empty_like(state)
init[:xy] = 0
init[w_ch] = 1
init[-c:] = 0
# --- Expansion branch: nearest-neighbor remap with negated flow ---
pre_expand = torch.empty_like(state)
pre_expand[:xy] = _torch_remap_relative(state[:xy], -flow[0], -flow[1], "nearest")
pre_expand[-oc:] = _torch_remap_relative(state[-oc:], -flow[0], -flow[1], "nearest")
pre_expand[w_ch][pre_expand[w_ch] == 0] = 1
# --- Shrink branch: scatter-add state into new positions ---
pre_shrink = state.clone()
pre_shrink[:xy] += flow
pos = (grid + pre_shrink[:xy]).round()
in_bounds = (pos[x_ch] >= 0) & (pos[x_ch] < w) & (pos[y_ch] >= 0) & (pos[y_ch] < h)
pre_shrink = torch.where(~in_bounds[None], init, pre_shrink)
scat_xy = pre_shrink[:xy].round()
pre_shrink[:xy] -= scat_xy
pre_shrink[:xy] = 0 # xy_mode='none' in upstream
def scat(tensor):
return _torch_scatter_add_relative(tensor, scat_xy[0], scat_xy[1])
# rp.torch_scatter_add_image on a bool tensor errors on modern torch;
# scatter-sum a float ones tensor and threshold to get the mask instead.
shrink_mask = scat(torch.ones(1, h, w, dtype=state.dtype, device=device)) > 0
# Drop expansion samples at positions that will be filled by shrink.
pre_expand = torch.where(shrink_mask, init, pre_expand)
# Regaussianize both branches together so duplicated-source groups are
# counted globally, then split back apart.
concat = torch.cat([pre_shrink, pre_expand], dim=2) # along width
concat[-c:], counts_image = regaussianize(concat[-c:])
concat[w_ch] = concat[w_ch] / counts_image[0]
concat[w_ch] = concat[w_ch].nan_to_num()
pre_shrink, expand = torch.chunk(concat, chunks=2, dim=2)
shrink = torch.empty_like(pre_shrink)
shrink[w_ch] = scat(pre_shrink[w_ch][None])[0]
shrink[:xy] = scat(pre_shrink[:xy] * pre_shrink[w_ch][None]) / shrink[w_ch][None]
shrink[-c:] = scat(pre_shrink[-c:] * pre_shrink[w_ch][None]) / scat(
pre_shrink[w_ch][None] ** 2
).sqrt()
output = torch.where(shrink_mask, shrink, expand)
output[w_ch] = output[w_ch] / output[w_ch].mean()
output[w_ch] += 1e-5
output[w_ch] **= 0.9999
return output
class NoiseWarper:
"""Maintain a warpable noise state and emit gaussian noise per frame.
Simplified from RyannDaGreat/CommonSource/noise_warp.py::NoiseWarper:
``scale_factor``, ``post_noise_alpha``, ``progressive_noise_alpha``, and
``warp_kwargs`` are all dropped since VOIDWarpedNoise always uses defaults.
"""
def __init__(self, c, h, w, device, dtype=torch.float32):
if c <= 0 or h <= 0 or w <= 0:
raise ValueError(
f"NoiseWarper: c/h/w must all be positive, got c={c} h={h} w={w}"
)
self.c = c
self.h = h
self.w = w
self.device = device
self.dtype = dtype
noise = torch.randn(c, h, w, dtype=dtype, device=device)
self._state = noise_to_state(noise)
@property
def noise(self):
# With scale_factor=1 the "downsample to respect weights" step is a
# size-preserving no-op; the weight-variance correction math still
# runs to stay faithful to upstream.
n = state_to_noise(self._state)
weights = self._state[2:3]
return n * weights / (weights ** 2).sqrt()
def __call__(self, dx, dy):
if dx.shape != dy.shape:
raise ValueError(
f"NoiseWarper: dx and dy must match, got {tuple(dx.shape)} vs {tuple(dy.shape)}"
)
flow = torch.stack([dx, dy]).to(self.device, self.dtype)
_, oflowh, ofloww = flow.shape
flow = _torch_resize_chw(flow, (self.h, self.w), "bilinear", copy=True)
flowh, floww = flow.shape[-2:]
# Upstream scales flow[0] by flowh/oflowh and flow[1] by floww/ofloww
# (channel-order appears swapped but harmless when H and W are scaled
# by the same factor, which is always the case for our callers).
flow[0] *= flowh / oflowh
flow[1] *= floww / ofloww
self._state = warp_state(self._state, flow)
return self
# ---------------------------------------------------------------------------
# RAFT optical flow wrapper (ported from raft.py)
# ---------------------------------------------------------------------------
class RaftOpticalFlow:
"""RAFT-large wrapper around a pre-loaded torchvision model.
``model`` must be the ``torchvision.models.optical_flow.raft_large`` module
with its weights already populated; this class is load-agnostic so the
caller owns downloading/offload concerns (see ``OpticalFlowLoader`` in
``nodes_void.py``). ``__call__`` returns a ``(2, H, W)`` flow.
"""
def __init__(self, model, device=None):
if device is None:
device = comfy.model_management.get_torch_device()
device = torch.device(device) if not isinstance(device, torch.device) else device
model = model.to(device)
model.eval()
self.device = device
self.model = model
def _preprocess(self, image_chw):
image = image_chw.to(self.device, torch.float32)
_, h, w = image.shape
new_h = (h // 8) * 8
new_w = (w // 8) * 8
image = _torch_resize_chw(image, (new_h, new_w), "bilinear", copy=False)
image = image * 2 - 1
return image[None]
def __call__(self, from_image, to_image):
"""``from_image``, ``to_image``: CHW float tensors in [0, 1]."""
if from_image.shape != to_image.shape:
raise ValueError(
f"RaftOpticalFlow: from_image and to_image must match, "
f"got {tuple(from_image.shape)} vs {tuple(to_image.shape)}"
)
_, h, w = from_image.shape
with torch.no_grad():
img1 = self._preprocess(from_image)
img2 = self._preprocess(to_image)
list_of_flows = self.model(img1, img2)
flow = list_of_flows[-1][0] # (2, new_h, new_w)
if flow.shape[-2:] != (h, w):
flow = _torch_resize_chw(flow, (h, w), "bilinear", copy=False)
return flow
# ---------------------------------------------------------------------------
# Narrow entry point used by VOIDWarpedNoise
# ---------------------------------------------------------------------------
def get_noise_from_video(
video_frames: torch.Tensor,
raft: RaftOpticalFlow,
*,
noise_channels: int = 16,
resize_frames: float = 0.5,
resize_flow: int = 8,
downscale_factor: int = 32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Produce optical-flow-warped gaussian noise from a video.
Args:
video_frames: ``(T, H, W, 3)`` uint8 torch tensor.
raft: Pre-loaded RAFT optical-flow wrapper (see ``RaftOpticalFlow``).
noise_channels: Channels in the output noise.
resize_frames: Pre-RAFT frame scale factor.
resize_flow: Post-flow up-scale factor applied to the optical flow;
the internal noise state is allocated at
``(resize_flow * resize_frames * H, resize_flow * resize_frames * W)``.
downscale_factor: Area-pool factor applied to the noise before return;
should evenly divide the internal noise resolution.
device: Target device. Defaults to ``comfy.model_management.get_torch_device()``.
Returns:
``(T, H', W', noise_channels)`` float32 noise tensor on ``device``.
"""
if not isinstance(resize_flow, int) or resize_flow < 1:
raise ValueError(
f"get_noise_from_video: resize_flow must be a positive int, got {resize_flow!r}"
)
if video_frames.ndim != 4 or video_frames.shape[-1] != 3:
raise ValueError(
"get_noise_from_video: video_frames must have shape (T, H, W, 3), "
f"got {tuple(video_frames.shape)}"
)
if video_frames.dtype != torch.uint8:
raise TypeError(
"get_noise_from_video: video_frames must be uint8 in [0, 255], "
f"got dtype {video_frames.dtype}"
)
if device is None:
device = comfy.model_management.get_torch_device()
device = torch.device(device) if not isinstance(device, torch.device) else device
if device.type == "cpu":
logging.warning(
"VOIDWarpedNoise: running get_noise_from_video on CPU; this will be "
"slow (minutes for ~45 frames). Use CUDA for interactive use."
)
T = video_frames.shape[0]
frames = video_frames.to(device).permute(0, 3, 1, 2).to(torch.float32) / 255.0
if resize_frames != 1.0:
new_h = max(1, int(frames.shape[2] * resize_frames))
new_w = max(1, int(frames.shape[3] * resize_frames))
frames = F.interpolate(frames, size=(new_h, new_w), mode="area")
_, _, H, W = frames.shape
internal_h = resize_flow * H
internal_w = resize_flow * W
if internal_h % downscale_factor or internal_w % downscale_factor:
logging.warning(
"VOIDWarpedNoise: internal noise size %dx%d is not divisible by "
"downscale_factor %d; output noise may have artifacts.",
internal_h, internal_w, downscale_factor,
)
with torch.no_grad():
warper = NoiseWarper(
c=noise_channels, h=internal_h, w=internal_w, device=device,
)
down_h = warper.h // downscale_factor
down_w = warper.w // downscale_factor
output = torch.empty(
(T, down_h, down_w, noise_channels), dtype=torch.float32, device=device,
)
def downscale(noise_chw):
# Area-pool to 1/downscale_factor then multiply by downscale_factor
# to adjust std (sqrt of pool area == downscale_factor for a
# square pool).
down = _torch_resize_chw(noise_chw, 1.0 / downscale_factor, "area", copy=False)
return down * downscale_factor
output[0] = downscale(warper.noise).permute(1, 2, 0)
prev = frames[0]
for i in range(1, T):
curr = frames[i]
flow = raft(prev, curr).to(device)
warper(flow[0], flow[1])
output[i] = downscale(warper.noise).permute(1, 2, 0)
prev = curr
return output

View File

@ -15,6 +15,7 @@ import torch
from comfy.cli_args import args from comfy.cli_args import args
import comfy.memory_management import comfy.memory_management
import comfy.model_management import comfy.model_management
import comfy.model_prefetch
import comfy_aimdo.model_vbar import comfy_aimdo.model_vbar
from latent_preview import set_preview_method from latent_preview import set_preview_method
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
if args.verbose == "DEBUG": if args.verbose == "DEBUG":
comfy_aimdo.control.analyze() comfy_aimdo.control.analyze()
comfy.model_management.reset_cast_buffers() comfy.model_management.reset_cast_buffers()
comfy.model_prefetch.cleanup_prefetch_queues()
comfy_aimdo.model_vbar.vbars_reset_watermark_limits() comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks: if has_pending_tasks:
@ -1017,7 +1019,12 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
combo_options = extra_info.get("options", []) combo_options = extra_info.get("options", [])
else: else:
combo_options = input_type combo_options = input_type
if val not in combo_options: is_multiselect = extra_info.get("multiselect", False)
if is_multiselect and isinstance(val, list):
invalid_vals = [v for v in val if v not in combo_options]
else:
invalid_vals = [val] if val not in combo_options else []
if invalid_vals:
input_config = info input_config = info
list_info = "" list_info = ""
@ -1032,7 +1039,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
error = { error = {
"type": "value_not_in_list", "type": "value_not_in_list",
"message": "Value not in list", "message": "Value not in list",
"details": f"{x}: '{val}' not in {list_info}", "details": f"{x}: {', '.join(repr(v) for v in invalid_vals)} not in {list_info}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
"input_config": input_config, "input_config": input_config,

View File

@ -28,7 +28,7 @@
#config for a1111 ui #config for a1111 ui
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed #all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
#a111: #a1111:
# base_path: path/to/stable-diffusion-webui/ # base_path: path/to/stable-diffusion-webui/
# checkpoints: models/Stable-diffusion # checkpoints: models/Stable-diffusion
# configs: models/Stable-diffusion # configs: models/Stable-diffusion

View File

@ -54,6 +54,8 @@ folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_enc
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
output_directory = os.path.join(base_path, "output") output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp") temp_directory = os.path.join(base_path, "temp")
input_directory = os.path.join(base_path, "input") input_directory = os.path.join(base_path, "input")
@ -432,7 +434,9 @@ def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, im
prefix_len = len(os.path.basename(filename_prefix)) prefix_len = len(os.path.basename(filename_prefix))
prefix = filename[:prefix_len + 1] prefix = filename[:prefix_len + 1]
try: try:
digits = int(filename[prefix_len + 1:].split('_')[0]) remainder = filename[prefix_len + 1:]
base_remainder = remainder.split('.')[0]
digits = int(base_remainder.split('_')[0])
except: except:
digits = 0 digits = 0
return digits, prefix return digits, prefix

10
main.py
View File

@ -1,13 +1,21 @@
import comfy.options import comfy.options
comfy.options.enable_args_parsing() comfy.options.enable_args_parsing()
from comfy.cli_args import args
if args.list_feature_flags:
import json
from comfy_api.feature_flags import CLI_FEATURE_FLAG_REGISTRY
print(json.dumps(CLI_FEATURE_FLAG_REGISTRY, indent=2)) # noqa: T201
raise SystemExit(0)
import os import os
import importlib.util import importlib.util
import shutil import shutil
import importlib.metadata import importlib.metadata
import folder_paths import folder_paths
import time import time
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import enables_dynamic_vram
from app.logger import setup_logger from app.logger import setup_logger
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)

View File

@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]: if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]] source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]: elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1)) source = torch.nn.functional.pad(source, (0, 1))
destination[..., -1] = 1.0 source[..., -1] = 1.0
return destination, source return destination, source

128
nodes.py
View File

@ -958,7 +958,7 @@ class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image"], ), "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ),
}, },
"optional": { "optional": {
"device": (["default", "cpu"], {"advanced": True}), "device": (["default", "cpu"], {"advanced": True}),
@ -968,7 +968,7 @@ class CLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B"
def load_clip(self, clip_name, type="stable_diffusion", device="default"): def load_clip(self, clip_name, type="stable_diffusion", device="default"):
clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION)
@ -1694,26 +1694,27 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK") RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image" FUNCTION = "load_image"
def load_image(self, image): def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
dtype = comfy.model_management.intermediate_dtype()
device = comfy.model_management.intermediate_device()
components = InputImpl.VideoFromFile(image_path).get_components() components = InputImpl.VideoFromFile(image_path).get_components()
if components.images.shape[0] > 0: if components.images.shape[0] > 0:
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu")) return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
# This code is left here to handle animated webp which pyav does not support loading
img = node_helpers.pillow(Image.open, image_path) img = node_helpers.pillow(Image.open, image_path)
output_images = [] output_images = []
output_masks = [] output_masks = []
w, h = None, None w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB") image = i.convert("RGB")
if len(output_images) == 0: if len(output_images) == 0:
@ -1728,25 +1729,15 @@ class LoadImage:
if 'A' in i.getbands(): if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0 mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image.to(dtype=dtype)) output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO": output_image = torch.cat(output_images, dim=0)
break # ignore all frames except the first one for MPO format output_mask = torch.cat(output_masks, dim=0)
if len(output_images) > 1: return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
@classmethod @classmethod
def IS_CHANGED(s, image): def IS_CHANGED(s, image):
@ -1763,57 +1754,49 @@ class LoadImage:
return True return True
class LoadImageMask:
class LoadImageMask(LoadImage):
ESSENTIALS_CATEGORY = "Image Tools" ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"] _color_channels = ["alpha", "red", "green", "blue"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory() types = super().INPUT_TYPES()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {
return {"required": "required": {
{"image": (sorted(files), {"image_upload": True}), **types["required"],
"channel": (s._color_channels, ), } "channel": (s._color_channels, )
} }
}
CATEGORY = "mask" CATEGORY = "mask"
RETURN_TYPES = ("MASK",) RETURN_TYPES = ("MASK",)
FUNCTION = "load_image" FUNCTION = "load_image_mask"
def load_image(self, image, channel):
image_path = folder_paths.get_annotated_filepath(image) def load_image_mask(self, image, channel):
i = node_helpers.pillow(Image.open, image_path) image_tensor, mask_tensor = super().load_image(image)
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA")
mask = None
c = channel[0].upper() c = channel[0].upper()
if c in i.getbands():
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0 if c == 'A':
mask = torch.from_numpy(mask) return (mask_tensor,)
if c == 'A':
mask = 1. - mask channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
if channel_idx < image_tensor.shape[-1]:
return (image_tensor[..., channel_idx].clone(),)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") empty_mask = torch.zeros(
return (mask.unsqueeze(0),) image_tensor.shape[:-1],
dtype=image_tensor.dtype,
device=image_tensor.device
)
return (empty_mask,)
@classmethod @classmethod
def IS_CHANGED(s, image, channel): def IS_CHANGED(s, image, channel):
image_path = folder_paths.get_annotated_filepath(image) return super().IS_CHANGED(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageOutput(LoadImage): class LoadImageOutput(LoadImage):
@ -1904,7 +1887,7 @@ class ImageInvert:
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "invert" FUNCTION = "invert"
CATEGORY = "image" CATEGORY = "image/color"
def invert(self, image): def invert(self, image):
s = 1.0 - image s = 1.0 - image
@ -1920,7 +1903,7 @@ class ImageBatch:
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "batch" FUNCTION = "batch"
CATEGORY = "image" CATEGORY = "image/batch"
DEPRECATED = True DEPRECATED = True
def batch(self, image1, image2): def batch(self, image1, image2):
@ -1977,7 +1960,7 @@ class ImagePadForOutpaint:
RETURN_TYPES = ("IMAGE", "MASK") RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "expand_image" FUNCTION = "expand_image"
CATEGORY = "image" CATEGORY = "image/transform"
def expand_image(self, image, left, top, right, bottom, feathering): def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size() d1, d2, d3, d4 = image.size()
@ -2120,7 +2103,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)", "ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)", "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetMask": "Conditioning (Set Mask)", "ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet (OLD)", "ControlNetApply": "Apply ControlNet (DEPRECATED)",
"ControlNetApplyAdvanced": "Apply ControlNet", "ControlNetApplyAdvanced": "Apply ControlNet",
# Latent # Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)", "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
@ -2138,6 +2121,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LatentFromBatch" : "Latent From Batch", "LatentFromBatch" : "Latent From Batch",
"RepeatLatentBatch": "Repeat Latent Batch", "RepeatLatentBatch": "Repeat Latent Batch",
# Image # Image
"EmptyImage": "Empty Image",
"SaveImage": "Save Image", "SaveImage": "Save Image",
"PreviewImage": "Preview Image", "PreviewImage": "Preview Image",
"LoadImage": "Load Image", "LoadImage": "Load Image",
@ -2145,15 +2129,15 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"LoadImageOutput": "Load Image (from Outputs)", "LoadImageOutput": "Load Image (from Outputs)",
"ImageScale": "Upscale Image", "ImageScale": "Upscale Image",
"ImageScaleBy": "Upscale Image By", "ImageScaleBy": "Upscale Image By",
"ImageInvert": "Invert Image", "ImageInvert": "Invert Image Colors",
"ImagePadForOutpaint": "Pad Image for Outpainting", "ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images", "ImageBatch": "Batch Images (DEPRECATED)",
"ImageCrop": "Image Crop", "ImageCrop": "Crop Image",
"ImageStitch": "Image Stitch", "ImageStitch": "Stitch Images",
"ImageBlend": "Image Blend", "ImageBlend": "Blend Images",
"ImageBlur": "Image Blur", "ImageBlur": "Blur Image",
"ImageQuantize": "Image Quantize", "ImageQuantize": "Quantize Image",
"ImageSharpen": "Image Sharpen", "ImageSharpen": "Sharpen Image",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels", "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size", "GetImageSize": "Get Image Size",
# _for_testing # _for_testing
@ -2278,7 +2262,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}") logging.warning(f"Error while calling comfy_entrypoint in {module_path}: {e}")
return False return False
else: else:
logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or NODES_LIST (need one).") logging.warning(f"Skip {module_path} module for custom nodes due to the lack of NODE_CLASS_MAPPINGS or comfy_entrypoint (need one).")
return False return False
except Exception as e: except Exception as e:
logging.warning(traceback.format_exc()) logging.warning(traceback.format_exc())
@ -2428,6 +2412,7 @@ async def init_builtin_extra_nodes():
"nodes_nop.py", "nodes_nop.py",
"nodes_kandinsky5.py", "nodes_kandinsky5.py",
"nodes_wanmove.py", "nodes_wanmove.py",
"nodes_ar_video.py",
"nodes_image_compare.py", "nodes_image_compare.py",
"nodes_zimage.py", "nodes_zimage.py",
"nodes_glsl.py", "nodes_glsl.py",
@ -2445,6 +2430,7 @@ async def init_builtin_extra_nodes():
"nodes_rtdetr.py", "nodes_rtdetr.py",
"nodes_frame_interpolation.py", "nodes_frame_interpolation.py",
"nodes_sam3.py", "nodes_sam3.py",
"nodes_void.py",
] ]
import_failed = [] import_failed = []

View File

@ -631,7 +631,7 @@ paths:
operationId: getFeatures operationId: getFeatures
tags: [system] tags: [system]
summary: Get enabled feature flags summary: Get enabled feature flags
description: Returns a dictionary of feature flag names to their enabled state. description: Returns a dictionary of feature flag names to their enabled state. Cloud deployments may include additional typed fields alongside the boolean flags.
responses: responses:
"200": "200":
description: Feature flags description: Feature flags
@ -641,6 +641,43 @@ paths:
type: object type: object
additionalProperties: additionalProperties:
type: boolean type: boolean
properties:
max_upload_size:
type: integer
format: int64
minimum: 0
description: "Maximum file upload size in bytes."
free_tier_credits:
type: integer
format: int32
minimum: 0
nullable: true
x-runtime: [cloud]
description: "[cloud-only] Credits available to free-tier users. Local ComfyUI returns null."
posthog_api_host:
type: string
format: uri
nullable: true
x-runtime: [cloud]
description: "[cloud-only] PostHog analytics proxy URL for frontend telemetry. Local ComfyUI returns null."
max_concurrent_jobs:
type: integer
format: int32
minimum: 0
nullable: true
x-runtime: [cloud]
description: "[cloud-only] Maximum concurrent jobs the authenticated user can run. Local ComfyUI returns null."
workflow_templates_version:
type: string
nullable: true
x-runtime: [cloud]
description: "[cloud-only] Version identifier for the workflow templates bundle. Local ComfyUI returns null."
workflow_templates_source:
type: string
nullable: true
enum: [dynamic_config_override, workflow_templates_version_json]
x-runtime: [cloud]
description: "[cloud-only] How the templates version was resolved. Local ComfyUI returns null."
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Node / Object Info # Node / Object Info
@ -1497,6 +1534,24 @@ paths:
type: string type: string
enum: [asc, desc] enum: [asc, desc]
description: Sort direction description: Sort direction
- name: job_ids
in: query
schema:
type: string
x-runtime: [cloud]
description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job."
- name: include_public
in: query
schema:
type: boolean
x-runtime: [cloud]
description: "[cloud-only] Include workspace-public assets in addition to the caller's own."
- name: asset_hash
in: query
schema:
type: string
x-runtime: [cloud]
description: "[cloud-only] Filter by exact content hash."
responses: responses:
"200": "200":
description: Asset list description: Asset list
@ -1542,6 +1597,49 @@ paths:
type: string type: string
format: uuid format: uuid
description: ID of an existing asset to use as the preview image description: ID of an existing asset to use as the preview image
id:
type: string
format: uuid
nullable: true
x-runtime: [cloud]
description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned."
application/json:
schema:
type: object
x-runtime: [cloud]
description: "[cloud-only] URL-based asset upload. Caller supplies a URL instead of a file body; the server fetches the content."
required:
- url
properties:
url:
type: string
format: uri
description: "[cloud-only] URL of the file to import as an asset"
name:
type: string
description: Display name for the asset
tags:
type: string
description: Comma-separated tags
user_metadata:
type: string
description: JSON-encoded user metadata
hash:
type: string
description: "Blake3 hash of the file content (e.g. blake3:abc123...)"
mime_type:
type: string
description: MIME type of the file (overrides auto-detected type)
preview_id:
type: string
format: uuid
description: ID of an existing asset to use as the preview image
id:
type: string
format: uuid
nullable: true
x-runtime: [cloud]
description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned."
responses: responses:
"201": "201":
description: Asset created description: Asset created
@ -1580,6 +1678,11 @@ paths:
user_metadata: user_metadata:
type: object type: object
additionalProperties: true additionalProperties: true
mime_type:
type: string
nullable: true
x-runtime: [cloud]
description: "[cloud-only] MIME type of the content, so the type is preserved without re-inspecting content. Ignored by local ComfyUI."
responses: responses:
"201": "201":
description: Asset created from hash description: Asset created from hash
@ -1644,6 +1747,11 @@ paths:
type: string type: string
format: uuid format: uuid
description: ID of the asset to use as the preview description: ID of the asset to use as the preview
mime_type:
type: string
nullable: true
x-runtime: [cloud]
description: "[cloud-only] MIME type override when auto-detection was wrong. Ignored by local ComfyUI."
responses: responses:
"200": "200":
description: Asset updated description: Asset updated
@ -1999,6 +2107,18 @@ components:
items: items:
type: string type: string
description: List of node IDs to execute (partial graph execution) description: List of node IDs to execute (partial graph execution)
workflow_id:
type: string
format: uuid
nullable: true
x-runtime: [cloud]
description: "[cloud-only] Cloud workflow entity ID for tracking and gallery association. Ignored by local ComfyUI."
workflow_version_id:
type: string
format: uuid
nullable: true
x-runtime: [cloud]
description: "[cloud-only] Cloud workflow version ID for pinning execution to a specific version. Ignored by local ComfyUI."
PromptResponse: PromptResponse:
type: object type: object
@ -2347,7 +2467,12 @@ components:
description: Device type (cuda, mps, cpu, etc.) description: Device type (cuda, mps, cpu, etc.)
index: index:
type: number type: number
description: Device index nullable: true
description: |
Device index within its type (e.g. CUDA ordinal for `cuda:0`,
`cuda:1`). `null` for devices with no index, including the CPU
device returned in `--cpu` mode (PyTorch's `torch.device('cpu').index`
is `None`).
vram_total: vram_total:
type: number type: number
description: Total VRAM in bytes description: Total VRAM in bytes
@ -2503,7 +2628,18 @@ components:
description: Alternative search terms for finding this node description: Alternative search terms for finding this node
essentials_category: essentials_category:
type: string type: string
description: Category override used by the essentials pack nullable: true
description: |
Category override used by the essentials pack. The
`essentials_category` key may be present with a string value,
present and `null`, or absent entirely:
- V1 nodes: `essentials_category` is **omitted** when the node
class doesn't define an `ESSENTIALS_CATEGORY` attribute, and
**`null`** if the attribute is explicitly set to `None`.
- V3 nodes (`comfy_api.latest.io`): `essentials_category` is
**always present**, and **`null`** for nodes whose `Schema`
doesn't populate it.
# ------------------------------------------------------------------- # -------------------------------------------------------------------
# Models # Models

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.15 comfyui-frontend-package==1.43.17
comfyui-workflow-templates==0.9.66 comfyui-workflow-templates==0.9.69
comfyui-embedded-docs==0.4.4 comfyui-embedded-docs==0.4.4
torch torch
torchsde torchsde

View File

@ -1,3 +1,4 @@
import errno
import os import os
import sys import sys
import asyncio import asyncio
@ -560,7 +561,7 @@ class PromptServer():
buffer.seek(0) buffer.seek(0)
return web.Response(body=buffer.read(), content_type=f'image/{image_format}', return web.Response(body=buffer.read(), content_type=f'image/{image_format}',
headers={"Content-Disposition": f"filename=\"{filename}\""}) headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
if 'channel' not in request.rel_url.query: if 'channel' not in request.rel_url.query:
channel = 'rgba' channel = 'rgba'
@ -580,7 +581,7 @@ class PromptServer():
buffer.seek(0) buffer.seek(0)
return web.Response(body=buffer.read(), content_type='image/png', return web.Response(body=buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""}) headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
elif channel == 'a': elif channel == 'a':
with Image.open(file) as img: with Image.open(file) as img:
@ -597,7 +598,7 @@ class PromptServer():
alpha_buffer.seek(0) alpha_buffer.seek(0)
return web.Response(body=alpha_buffer.read(), content_type='image/png', return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""}) headers={"Content-Disposition": f"attachment; filename=\"{filename}\""})
else: else:
# Use the content type from asset resolution if available, # Use the content type from asset resolution if available,
# otherwise guess from the filename. # otherwise guess from the filename.
@ -614,7 +615,7 @@ class PromptServer():
return web.FileResponse( return web.FileResponse(
file, file,
headers={ headers={
"Content-Disposition": f"filename=\"{filename}\"", "Content-Disposition": f"attachment; filename=\"{filename}\"",
"Content-Type": content_type "Content-Type": content_type
} }
) )
@ -1246,7 +1247,13 @@ class PromptServer():
address = addr[0] address = addr[0]
port = addr[1] port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start() try:
await site.start()
except OSError as e:
if e.errno == errno.EADDRINUSE:
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
raise SystemExit(1)
raise
if not hasattr(self, 'address'): if not hasattr(self, 'address'):
self.address = address #TODO: remove this self.address = address #TODO: remove this

View File

@ -0,0 +1,78 @@
from comfy_api.latest._io import Combo, MultiCombo
def test_multicombo_serializes_multi_select_as_object():
multi_combo = MultiCombo.Input(
id="providers",
options=["a", "b", "c"],
default=["a"],
)
serialized = multi_combo.as_dict()
assert serialized["multiselect"] is True
assert "multi_select" in serialized
assert serialized["multi_select"] == {}
def test_multicombo_serializes_multi_select_with_placeholder_and_chip():
multi_combo = MultiCombo.Input(
id="providers",
options=["a", "b", "c"],
default=["a"],
placeholder="Select providers",
chip=True,
)
serialized = multi_combo.as_dict()
assert serialized["multiselect"] is True
assert serialized["multi_select"] == {
"placeholder": "Select providers",
"chip": True,
}
def test_combo_does_not_serialize_multiselect():
"""Regular Combo should not have multiselect in its serialized output."""
combo = Combo.Input(
id="choice",
options=["a", "b", "c"],
)
serialized = combo.as_dict()
# Combo sets multiselect=False, but prune_dict keeps False (not None),
# so it should be present but False
assert serialized.get("multiselect") is False
assert "multi_select" not in serialized
def _validate_combo_values(val, combo_options, is_multiselect):
"""Reproduce the validation logic from execution.py for testing."""
if is_multiselect and isinstance(val, list):
return [v for v in val if v not in combo_options]
else:
return [val] if val not in combo_options else []
def test_multicombo_validation_accepts_valid_list():
options = ["a", "b", "c"]
assert _validate_combo_values(["a", "b"], options, True) == []
def test_multicombo_validation_rejects_invalid_values():
options = ["a", "b", "c"]
assert _validate_combo_values(["a", "x"], options, True) == ["x"]
def test_multicombo_validation_accepts_empty_list():
options = ["a", "b", "c"]
assert _validate_combo_values([], options, True) == []
def test_combo_validation_rejects_list_even_with_valid_items():
"""A regular Combo should not accept a list value."""
options = ["a", "b", "c"]
invalid = _validate_combo_values(["a", "b"], options, False)
assert len(invalid) > 0

View File

@ -0,0 +1,109 @@
"""Tests for comfy.deploy_environment."""
import os
import pytest
from comfy import deploy_environment
from comfy.deploy_environment import get_deploy_environment
@pytest.fixture(autouse=True)
def _reset_cache_and_install_dir(tmp_path, monkeypatch):
"""Reset the functools cache and point the ComfyUI install dir at a tmp dir for each test."""
get_deploy_environment.cache_clear()
monkeypatch.setattr(deploy_environment, "_COMFY_INSTALL_DIR", str(tmp_path))
yield
get_deploy_environment.cache_clear()
def _write_env_file(tmp_path, content: str) -> str:
"""Write the env file with exact content (no newline translation).
`newline=""` disables Python's text-mode newline translation so the bytes
on disk match the literal string passed in, regardless of host OS.
Newline-style tests (CRLF, lone CR) rely on this.
"""
path = os.path.join(str(tmp_path), ".comfy_environment")
with open(path, "w", encoding="utf-8", newline="") as f:
f.write(content)
return path
class TestGetDeployEnvironment:
def test_returns_local_git_when_file_missing(self):
assert get_deploy_environment() == "local-git"
def test_reads_value_from_file(self, tmp_path):
_write_env_file(tmp_path, "local-desktop2-standalone\n")
assert get_deploy_environment() == "local-desktop2-standalone"
def test_strips_trailing_whitespace_and_newline(self, tmp_path):
_write_env_file(tmp_path, " local-desktop2-standalone \n")
assert get_deploy_environment() == "local-desktop2-standalone"
def test_only_first_line_is_used(self, tmp_path):
_write_env_file(tmp_path, "first-line\nsecond-line\n")
assert get_deploy_environment() == "first-line"
def test_crlf_line_ending(self, tmp_path):
# Windows editors often save text files with CRLF line endings.
# The CR must not end up in the returned value.
_write_env_file(tmp_path, "local-desktop2-standalone\r\n")
assert get_deploy_environment() == "local-desktop2-standalone"
def test_crlf_multiline_only_first_line_used(self, tmp_path):
_write_env_file(tmp_path, "first-line\r\nsecond-line\r\n")
assert get_deploy_environment() == "first-line"
def test_crlf_with_surrounding_whitespace(self, tmp_path):
_write_env_file(tmp_path, " local-desktop2-standalone \r\n")
assert get_deploy_environment() == "local-desktop2-standalone"
def test_lone_cr_line_ending(self, tmp_path):
# Classic-Mac / some legacy editors use a bare CR.
# Universal-newlines decoding treats it as a line terminator too.
_write_env_file(tmp_path, "local-desktop2-standalone\r")
assert get_deploy_environment() == "local-desktop2-standalone"
def test_empty_file_falls_back_to_default(self, tmp_path):
_write_env_file(tmp_path, "")
assert get_deploy_environment() == "local-git"
def test_empty_after_whitespace_strip_falls_back_to_default(self, tmp_path):
_write_env_file(tmp_path, " \n")
assert get_deploy_environment() == "local-git"
def test_strips_control_chars_within_first_line(self, tmp_path):
# Embedded NUL/control chars in the value should be stripped
# (header-injection / smuggling protection).
_write_env_file(tmp_path, "abc\x00\x07xyz\n")
assert get_deploy_environment() == "abcxyz"
def test_strips_non_ascii_characters(self, tmp_path):
_write_env_file(tmp_path, "café-é\n")
assert get_deploy_environment() == "caf-"
def test_caps_read_at_128_bytes(self, tmp_path):
# A single huge line with no newline must not be fully read into memory.
huge = "x" * 10_000
_write_env_file(tmp_path, huge)
result = get_deploy_environment()
assert result == "x" * 128
def test_result_is_cached_across_calls(self, tmp_path):
path = _write_env_file(tmp_path, "first_value\n")
assert get_deploy_environment() == "first_value"
# Overwrite the file — cached value should still be returned.
with open(path, "w", encoding="utf-8") as f:
f.write("second_value\n")
assert get_deploy_environment() == "first_value"
def test_unreadable_file_falls_back_to_default(self, tmp_path, monkeypatch):
_write_env_file(tmp_path, "should_not_be_used\n")
def _boom(*args, **kwargs):
raise OSError("simulated read failure")
monkeypatch.setattr("builtins.open", _boom)
assert get_deploy_environment() == "local-git"

View File

@ -1,10 +1,15 @@
"""Tests for feature flags functionality.""" """Tests for feature flags functionality."""
import pytest
from comfy_api.feature_flags import ( from comfy_api.feature_flags import (
get_connection_feature, get_connection_feature,
supports_feature, supports_feature,
get_server_features, get_server_features,
CLI_FEATURE_FLAG_REGISTRY,
SERVER_FEATURE_FLAGS, SERVER_FEATURE_FLAGS,
_coerce_flag_value,
_parse_cli_feature_flags,
) )
@ -96,3 +101,83 @@ class TestFeatureFlags:
result = get_connection_feature(sockets_metadata, "sid1", "any_feature") result = get_connection_feature(sockets_metadata, "sid1", "any_feature")
assert result is False assert result is False
assert supports_feature(sockets_metadata, "sid1", "any_feature") is False assert supports_feature(sockets_metadata, "sid1", "any_feature") is False
class TestCoerceFlagValue:
"""Test suite for _coerce_flag_value."""
def test_registered_bool_true(self):
assert _coerce_flag_value("show_signin_button", "true") is True
assert _coerce_flag_value("show_signin_button", "True") is True
def test_registered_bool_false(self):
assert _coerce_flag_value("show_signin_button", "false") is False
assert _coerce_flag_value("show_signin_button", "FALSE") is False
def test_unregistered_key_stays_string(self):
assert _coerce_flag_value("unknown_flag", "true") == "true"
assert _coerce_flag_value("unknown_flag", "42") == "42"
def test_bool_typo_raises(self):
"""Strict bool: typos like 'ture' or 'yes' must raise so the flag can be dropped."""
with pytest.raises(ValueError):
_coerce_flag_value("show_signin_button", "ture")
with pytest.raises(ValueError):
_coerce_flag_value("show_signin_button", "yes")
with pytest.raises(ValueError):
_coerce_flag_value("show_signin_button", "1")
with pytest.raises(ValueError):
_coerce_flag_value("show_signin_button", "")
def test_failed_int_coercion_raises(self, monkeypatch):
"""Malformed values for typed flags must raise; caller decides what to do."""
monkeypatch.setitem(
CLI_FEATURE_FLAG_REGISTRY,
"test_int_flag",
{"type": "int", "default": 0, "description": "test"},
)
with pytest.raises(ValueError):
_coerce_flag_value("test_int_flag", "not_a_number")
class TestParseCliFeatureFlags:
"""Test suite for _parse_cli_feature_flags."""
def test_single_flag(self, monkeypatch):
monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button=true"]})())
result = _parse_cli_feature_flags()
assert result == {"show_signin_button": True}
def test_missing_equals_defaults_to_true(self, monkeypatch):
"""Bare flag without '=' is treated as the string 'true' (and coerced if registered)."""
monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["show_signin_button", "valid=1"]})())
result = _parse_cli_feature_flags()
assert result == {"show_signin_button": True, "valid": "1"}
def test_empty_key_skipped(self, monkeypatch):
monkeypatch.setattr("comfy_api.feature_flags.args", type("Args", (), {"feature_flag": ["=value", "valid=1"]})())
result = _parse_cli_feature_flags()
assert result == {"valid": "1"}
def test_invalid_bool_value_dropped(self, monkeypatch, caplog):
"""A typo'd bool value must be dropped entirely, not silently set to False
and not stored as a raw string. A warning must be logged."""
monkeypatch.setattr(
"comfy_api.feature_flags.args",
type("Args", (), {"feature_flag": ["show_signin_button=ture", "valid=1"]})(),
)
with caplog.at_level("WARNING"):
result = _parse_cli_feature_flags()
assert result == {"valid": "1"}
assert "show_signin_button" not in result
assert any("show_signin_button" in r.message and "drop" in r.message.lower() for r in caplog.records)
class TestCliFeatureFlagRegistry:
"""Test suite for the CLI feature flag registry."""
def test_registry_entries_have_required_fields(self):
for key, info in CLI_FEATURE_FLAG_REGISTRY.items():
assert "type" in info, f"{key} missing 'type'"
assert "default" in info, f"{key} missing 'default'"
assert "description" in info, f"{key} missing 'description'"

View File

@ -69,7 +69,11 @@ async def test_listuserdata_full_info(aiohttp_client, app, tmp_path):
assert len(result) == 1 assert len(result) == 1
assert result[0]["path"] == "file1.txt" assert result[0]["path"] == "file1.txt"
assert "size" in result[0] assert "size" in result[0]
assert "modified" in result[0] assert isinstance(result[0]["modified"], int)
assert isinstance(result[0]["created"], int)
# Verify millisecond magnitude (timestamps after year 2000 in ms are > 946684800000)
assert result[0]["modified"] > 946684800000
assert result[0]["created"] > 946684800000
async def test_listuserdata_split_path(aiohttp_client, app, tmp_path): async def test_listuserdata_split_path(aiohttp_client, app, tmp_path):