Compare commits

...

76 Commits

Author SHA1 Message Date
comfyanonymous
0b04660ba3
Speed up anima a bit on nvidia. (#14181)
2026-05-29 22:47:10 -07:00
comfyanonymous
6e1ef2311b
Remove useless code. (#14178)
2026-05-29 16:26:46 -07:00
Alexander Piskun
ec1896aceb
[Partner Nodes] feat: add new nodes for Tripo3D P1 model (#14155)
2026-05-29 09:19:53 -07:00
Jukka Seppänen
54d5be4a8e
Fix background removal mask output shape (#14171) 2026-05-29 09:14:32 -07:00
Alexander Piskun
ea5b092576
[Partner Nodes] fix: removed "beta" models versions from Grok nodes (#14170) 2026-05-29 09:08:43 -07:00
Terry Jia
e7214d78ee
feat: add model_info output to Load3D node (#14144)
2026-05-29 00:06:00 -07:00
Daxiong (Lin)
b10a61615c
chore: update workflow templates to v0.9.91 (#14163) 2026-05-28 22:42:17 -07:00
rattus
684296148e
float: use CK stochastic rounding cuda kernel (#13971)
2026-05-28 19:23:42 -07:00
comfyanonymous
ade4dfd96a
Update and pin comfy-kitchen version to 0.2.9 (#14161) 2026-05-28 19:23:17 -07:00
Terry Jia
26aad73cd7
refactor: drop rotation from Load3DCamera (#14159) 2026-05-28 17:42:47 -07:00
comfyanonymous
bcf805aaea
Bump pyav package to fix some image loading issues. (#14160) 2026-05-28 17:38:01 -07:00
Luke Mino-Altherr
6dd3c67427
Add unreviewed merge detector for SOC 2 compliance (#14146) 2026-05-28 15:07:22 -07:00
Charles Chan
6ceec29bd1
feat: add overwrite/increment to SaveImageTextDataSetToFolderNode (#13215)
2026-05-28 10:12:04 -07:00
Alexander Piskun
cffa2f43aa
[Partner Nodes] chore: update the category of the Beeble nodes (#14156)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-28 19:23:51 +03:00
Alexander Piskun
4af9a47227
[Partner Nodes] fix: add runtime check for SeeDance2 image inputs (#14152)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-28 01:03:28 -07:00
Terry Jia
be06873d9b
Make Load3D model_file optional by adding "none" choice (#13379) 2026-05-27 23:16:28 -07:00
Terry Jia
8ed308bcde
feat: add camera intrinsics fields to Load3DCamera info (#14143) 2026-05-27 22:34:43 -07:00
Alexis Rolland
174208df6b
chore: Update nodes categories (#14145)
* Move dataset/text nodes to text category

* Rename category utils into utilities

* Rename category api node into partner

* Move categories conditioning, latent, sampling, model_patches, training, etc. under model category

* Dispatch partner nodes into 3d, audio, image, text, video categories

* Move PreviewAny node to utilities category
2026-05-27 20:43:33 -04:00
comfyanonymous
85a403d1ea
Disable sage attention in stable audio dit and VAE. (#14148) 2026-05-27 20:35:03 -04:00
Jukka Seppänen
987a937658
Support context window for PiD and fix lq_latent rounding (#14136) 2026-05-27 12:08:06 -07:00
Alexander Piskun
51ef17e8a6
[Partner Nodes] feat: Beeble SwitchX nodes (#14137)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-27 11:57:55 -07:00
Alexander Piskun
b1cba6f4e6
convert nodes_lt_upsampler nodes to V3 schema (#12423) 2026-05-27 11:11:43 -07:00
Alexander Piskun
175e85466a
[Partner Nodes] feat: add Krea2 nodes (#14130)
2026-05-27 05:39:32 -07:00
Daxiong (Lin)
53eba227f5
chore: update workflow templates to v0.9.85 (#14134) 2026-05-27 05:32:58 -07:00
Alexander Piskun
0cce76d402
[Partner Nodes] feat: improve video references uploading for SeeDance 2 (#14098)
* [Partner Nodes] feat: improve video references uploading for SeeDance 2

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] hash video via memoryview to avoid memory copy

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-26 23:44:27 -07:00
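The "hash video via memoryview to avoid memory copy" item above can be illustrated with a minimal sketch. It is not the code from the commit: the function name, the chunk size, and the choice of SHA-256 are assumptions made for illustration. The point is only that slicing a `memoryview` yields a view onto the same buffer, whereas slicing `bytes` allocates a new object per chunk.

```python
import hashlib


def hash_buffer_without_copy(data: bytes, chunk_size: int = 1 << 20) -> str:
    """Hash a large in-memory buffer chunk by chunk (hypothetical helper).

    SHA-256 is an assumption; the commit does not say which hash is used.
    """
    view = memoryview(data)  # zero-copy window over the original buffer
    digest = hashlib.sha256()
    for start in range(0, len(view), chunk_size):
        # Slicing a memoryview does not copy the underlying bytes.
        digest.update(view[start:start + chunk_size])
    return digest.hexdigest()
```

The result is identical to hashing the whole buffer in one call; only the peak memory behaviour differs.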
Barish Ozbay
2072d3e46d
fix: Stop LTXVCropGuides leaving stray latent frames when guides share a start position (#13882) 2026-05-26 19:59:32 -07:00
comfyanonymous
e75a92c1b6
Add memory usage factor for lens model. (#14124)
2026-05-26 18:06:51 -07:00
comfyanonymous
d8d860a588
Closer memory usage factors for PID (#14123) 2026-05-26 18:04:55 -07:00
Jukka Seppänen
28f4ef277c
feat: Support NVIDIA PixelDiT and PiD (CORE-201) (#14103) 2026-05-26 17:50:14 -07:00
Matt Miller
921775704c
openapi: document QueueManageResponse body on POST /api/queue (#14117)
* openapi: document QueueManageResponse body on POST /api/queue

The Cloud runtime returns a JSON body from POST /api/queue describing which
prompts were deleted and whether the queue was cleared. The spec previously
declared a bare 200 with no schema, so generated clients had no type for the
response.

Adds a QueueManageResponse schema ({deleted, cleared}) and references it from
the 200 response. Tagged x-runtime: [cloud] with a [cloud-only] description:
local ComfyUI returns an empty 200 body, so both fields are nullable.

* openapi: fix GET /api/hub/labels response to the label-catalog shape (#14118)

* openapi: fix GET /api/hub/labels response to the label-catalog shape

GET /api/hub/labels returns the catalog of available labels you can filter by,
which the Cloud runtime serves as {labels: HubLabelInfo[]} (slug name,
display_name, and a type category: tag/model/custom_node).

The spec had this operation returning a bare array of HubLabel ({id, name,
color}) — that schema models the label chips attached to a published workflow
(HubWorkflow.labels), a different object. The catalog schema (HubLabelInfo)
already existed but was unreferenced.

Repoints the 200 response to a new HubLabelListResponse wrapper over the
existing HubLabelInfo. HubLabel is unchanged and still used by
HubWorkflow.labels. Endpoint remains x-runtime: [cloud].

* openapi: add Cloud-runtime fields (workflow_id, execution_error) to JobEntry (#14119)

* openapi: add Cloud-runtime fields workflow_id, execution_error to JobEntry

The Cloud runtime returns two additional fields on JobEntry that the spec
didn't declare:

- workflow_id: UUID of the Cloud workflow entity the job is associated with
- execution_error: structured ComfyUI execution error for failed jobs
  (reuses the existing ExecutionError schema)

Both tagged x-runtime: [cloud] with [cloud-only] descriptions; local ComfyUI
does not populate them.

* openapi: document Cloud-runtime request fields on POST /api/assets/export (#14120)

The Cloud runtime accepts three request fields on /api/assets/export that the
spec didn't declare:

- job_ids: include all assets associated with the given jobs
- naming_strategy: how to name files in the ZIP (enum, default group_by_job_time)
- job_asset_name_filters: optional per-job asset-name allowlist

Also drops asset_ids from required: the runtime supports exporting by job_ids
alone, so neither field is individually required.

/api/assets/export is already x-runtime: [cloud]; these are plain field
additions under that endpoint-level tag.
2026-05-26 16:25:20 -07:00
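As a rough sketch of what the nullable `QueueManageResponse` fields mean for a client, the snippet below parses a POST /api/queue response body so that both the Cloud JSON body and the empty local body are handled. The field types are assumptions (the commit only says `deleted` describes which prompts were deleted and `cleared` whether the queue was cleared), and `parse_queue_manage_response` is a hypothetical helper, not part of ComfyUI.

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class QueueManageResponse:
    # Assumed types: a list of prompt ids and a boolean flag.
    deleted: Optional[list] = None
    cleared: Optional[bool] = None


def parse_queue_manage_response(body: str) -> QueueManageResponse:
    """Parse the response body; an empty body (local ComfyUI) yields nulls."""
    if not body.strip():
        return QueueManageResponse()
    payload = json.loads(body)
    return QueueManageResponse(
        deleted=payload.get("deleted"),
        cleared=payload.get("cleared"),
    )
```

Keeping both fields optional lets one client type serve both runtimes, which matches the commit's stated reason for making them nullable.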
Jukka Seppänen
f9f54cae42
Lens: some cleanup (#14112)
* Lens: remove redundant memory optimization
2026-05-26 10:32:53 +03:00
Jukka Seppänen
41812fa0ac
feat: Microsoft Lens support (CORE-248) (#14077) 2026-05-25 23:01:51 -07:00
Ivan Zorin
57414dadfe
fix: cross-attention AdaLN scale, shift, sigma parameters calculation (#14097) 2026-05-25 20:07:09 -07:00
Jedrzej Kosinski
88956e77af
multigpu: use unet_manual_cast for SelectModelDevice compute dtype (#14108) 2026-05-25 20:03:37 -07:00
comfyanonymous
da49b7d0b6
Remove useless annotations imports. (#14105) 2026-05-25 19:23:29 -07:00
Jedrzej Kosinski
0a2dd86e78
MultiGPU Work Units For Accelerated Sampling (CORE-184) (#7063) 2026-05-25 18:26:40 -07:00
Daxiong (Lin)
04879a8113
Add new open-source model and built-in tool blueprints (#13980)
2026-05-25 12:25:16 -07:00
Matt Miller
6de7fc063b
Emit hash alongside asset_hash on all Asset responses (#13739)
* Emit `hash` alongside `asset_hash` on all Asset responses

Add a `hash` field to the Asset response schema that carries the same
value as the existing `asset_hash` field. Both fields are now populated
in _build_asset_response, so every Asset-returning endpoint (GET, POST,
PUT) includes both.

No existing fields are removed. Tests updated to assert both fields.

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>

* Tighten hash field tests and DRY response builder

- Extract assert_hash_fields_consistent() helper that verifies presence
  parity and value equality, replacing body.get()-based assertions that
  treated missing keys and explicit nulls identically.
- Conftest seeded_asset fixture and seed-asset list assertions now check
  key absence directly, so a regression that surfaces null fields would
  be caught (validates exclude_none behavior).
- DRY duplicate hash expression in _build_asset_response.
- Add list-endpoint coverage asserting hash is present and consistent on
  populated assets.
- Add schema-level test asserting AssetCreated inherits the hash field
  from Asset, guarding against future inheritance drift.

---------

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>
Co-authored-by: guill <jacob.e.segal@gmail.com>
2026-05-25 11:21:35 -07:00
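The commit message describes `assert_hash_fields_consistent()` as checking presence parity and value equality. A helper with that behaviour might look like the sketch below; this is an illustration written from the description, not the test code from the repository, and the dict-based signature is an assumption.

```python
def assert_hash_fields_consistent(body: dict) -> None:
    """Check that `hash` and `asset_hash` appear together and agree.

    Uses key membership rather than body.get(), so a missing key and an
    explicit null are not treated as the same thing.
    """
    has_hash = "hash" in body
    has_asset_hash = "asset_hash" in body
    # Presence parity: both keys present, or both absent.
    assert has_hash == has_asset_hash, "hash and asset_hash must appear together"
    if has_hash:
        # Value equality: the new field mirrors the existing one.
        assert body["hash"] == body["asset_hash"], "hash must equal asset_hash"
```

Checking membership directly is what lets a regression that surfaces null fields fail the test, as the second bullet of the commit notes.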
Daxiong (Lin)
a4141a0f5a
chore: update embedded docs to v0.5.1 (#14101) 2026-05-26 01:57:18 +08:00
comfyanonymous
0077d78cbf
Save Image advanced node (CORE-32) (#13850)
2026-05-24 23:01:34 -04:00
Talmaj
63bcaec5d1
Add colored logs (#14036) 2026-05-25 10:00:55 +08:00
rattus
b30e980a20
cache-ram: lower thresholds (#14089)
Use the RAM right up to the wire, as the community is a bit accustomed to.

This trades off headroom for the case where large, chunky intermediates
arrive and potentially hit pagefile/swap, but a lot of people have
"it just fits" workflows out there, so strike a compromise with
75->90%.

Disable the inactive cache for all but the very high RAM users.
2026-05-24 15:26:50 -07:00
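To make the 75->90% trade-off concrete, here is a small arithmetic sketch. It assumes the percentage is the share of total system RAM that may be in use before the cache starts evicting; the commit message does not spell out the exact definition, and `headroom_bytes` is a hypothetical helper.

```python
GIB = 1 << 30


def headroom_bytes(total_ram: int, threshold_percent: int) -> int:
    """RAM left free once usage reaches the threshold (assumed meaning)."""
    usable = total_ram * threshold_percent // 100
    return total_ram - usable


total = 64 * GIB
old_headroom = headroom_bytes(total, 75)  # 16 GiB left for intermediates
new_headroom = headroom_bytes(total, 90)  # roughly 6.4 GiB left
```

On a 64 GiB machine under this assumption, the change gives the cache about 9.6 GiB more room and leaves correspondingly less space for large intermediates before the system reaches pagefile or swap.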
rattus
39f963b4b0
mark loads to pins as cold immediately (#14088)
This does the posix_fadvise to kick pins out of the disk cache (to
avoid a double copy in RAM).
2026-05-24 15:25:59 -07:00
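The `posix_fadvise` idea in this commit can be sketched in plain Python. This is not the commit's implementation: the function name and the read loop are assumptions, and the advice is skipped on platforms where `os.posix_fadvise` is unavailable. After the file contents are copied into process memory, `POSIX_FADV_DONTNEED` hints to the kernel that the cached pages can be dropped, so the data is not held twice in RAM.

```python
import os


def read_then_mark_cold(path: str, chunk_size: int = 1 << 20) -> bytes:
    """Read a file fully, then advise the kernel to drop its cached pages."""
    fd = os.open(path, os.O_RDONLY)
    try:
        chunks = []
        while True:
            chunk = os.read(fd, chunk_size)
            if not chunk:  # empty read means end of file
                break
            chunks.append(chunk)
        if hasattr(os, "posix_fadvise"):
            # offset=0, len=0 covers the whole file; this is a hint only.
            os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_DONTNEED)
        return b"".join(chunks)
    finally:
        os.close(fd)
```

The call is advisory: the kernel may keep the pages anyway, and dirty pages are not dropped until they have been written back.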
Matt Miller
ea62dc11c9
openapi: fix invalid BillingStatus schema (object + enum hybrid) (#14071)
2026-05-24 10:58:35 +08:00
Robin Huang
32a7092c52
fix: correct description of where compiled FE files live (#14013) 2026-05-24 10:48:31 +08:00
comfyanonymous
08d809d128
Fix --use-flash-attention ignored when xformers installed. (#14083) 2026-05-23 17:44:28 -07:00
Comfy Org PR Bot
0af123022d
Bump comfyui-frontend-package to 1.44.19 (#14074) 2026-05-24 08:27:52 +08:00
comfyanonymous
d80fcafee7
Remove dead code. (#14072)
2026-05-22 19:56:36 -07:00
Matt Miller
187442cca4
openapi: add enum values + FeedbackRequest schema for cloud cutover (PR E) (#14070)
* openapi: add enum values + FeedbackRequest schema for cloud cutover (PR E)

Adds missing cloud-runtime enum values to vendor schemas that the
cloud runtime emits but vendor declared as plain strings.

Changes:
  - JobEntry.status: enum [pending, in_progress, completed, failed, cancelled]
  - JobDetailResponse.status: same enum
  - BillingStatus: enum [awaiting_payment_method, pending_payment, paid,
      payment_failed, inactive]
  - FeedbackRequest schema added (with type enum)
  - /api/feedback POST: requestBody now $refs FeedbackRequest

All cloud-runtime-emitted; no impact on OSS-local semantics.

Identified via Comfy-Org/cloud's TestCutoverSafe gate (BE-1106) as
the remaining schema-level divergences after PRs A-D landed and got
synced.

* openapi: add type enum to Workspace schema (cutover follow-up)

Cloud's Workspace runtime shape includes a 'type' field with enum
[personal, team] that vendor's Workspace was missing. Cloud handlers
reference the generated ingest.WorkspaceType Go enum.

Same kind of surgical addition as JobEntry.status / BillingStatus /
JobDetailResponse.status in this PR — adds cloud-runtime field to
existing vendor schema.
2026-05-22 18:23:22 -07:00
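For a client consuming these schemas, tightening `status` from a plain string to an enum could be mirrored as below. The enum values are taken from the commit message; the class name `JobStatus` and the `parse_job_status` helper are hypothetical and only illustrate the effect of the change.

```python
from enum import Enum


class JobStatus(str, Enum):
    # Values listed for JobEntry.status and JobDetailResponse.status.
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


def parse_job_status(raw: str) -> JobStatus:
    """Convert a wire value to the enum; unknown values raise ValueError."""
    return JobStatus(raw)
```

With a plain string schema, a typo such as "canceled" would pass through silently; with the enum it fails at the parsing boundary.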
Matt Miller
c3c881f37b
openapi: rename cloud-side response schemas to match runtime (PR D) (#14065)
* openapi: rename cloud-side response schemas to match runtime (PR D)

Follow-up to the BE-1106 stack (#14060/61/63). Cloud's Go handlers
reference response schemas by name (e.g., ingest.WorkflowResponse,
ingest.SubscribeResponse), but vendor's matching operations were
declaring those responses against differently-named vendor-side
schemas (CloudWorkflow, BillingSubscription, etc.). After the stack
landed, schemas like WorkflowResponse exist in vendor but weren't
referenced by any path, so codegen pruned the unreferenced types.

This PR:
  1. Updates 34 operation $refs in cloud-runtime paths to point to
     the schema names cloud's handlers expect (e.g., CloudWorkflow →
     WorkflowResponse on /api/workflows/{workflow_id}).
  2. Adds 12 cloud-only schemas that weren't in vendor yet but are
     referenced by these renames (e.g., SubscribeResponse,
     CancelSubscriptionResponse, BillingOpStatusResponse). Each
     copied verbatim from Comfy-Org/cloud's hand-written ingest spec
     and tagged x-runtime: [cloud] with a [cloud-only] description
     prefix.

Schema renames span the same domains as the operationId renames in
PR A: billing/subscriptions (7 schemas), workflows (5), userdata (3),
jobs (2), hub (2), history (2), auth/workspace (4), and misc cloud
endpoints (9).

Convergent safety check after this lands (against cloud's
TestCutoverSafe gate, BE-1106):
  Pre-PR D:   205 missing handler refs
  Post-PR D:  105 missing handler refs (-49%)
  Cumulative since the original 938-ref baseline: -89%

The remaining 105 are a Phase 3 follow-up (response headers,
text/plain responses, codegen-derived enum sub-types, and a small
set of inline-response-schema operations that vendor declares
inline where cloud has named-schema $refs).

* openapi: drop PR-label comment from new schemas block

PR-internal labels don't belong in committed code — future readers
won't know what 'PR D' means and the marker stops being useful the
moment this PR merges.
2026-05-22 16:34:52 -07:00
Matt Miller
7984a6a38e
openapi: rename 55 cloud-side operationIds to match runtime (PR A of 3) (#14060)
* openapi: rename 55 cloud-side operationIds to match runtime handlers

For the 55 operations below, vendor's operationId did not match the
name cloud's runtime handlers expect. Generated types from vendor
therefore had different names (e.g. CreateSubscription200JSONResponse)
than what cloud handlers reference (Subscribe200JSONResponse), which
blocks the post-cutover combined-spec codegen.

All 55 renames target the cloud-runtime-authoritative name. Several
of these endpoints are shared concepts (queue, settings, userdata,
object_info) that OSS local also serves — the rename aligns vendor
with the longstanding cloud handler-side convention to unblock the
shared codegen. No request/response *shape* changes in this PR; only
operationId labels.

Notable categories:
  - Billing/subscriptions: 7 renames (subscribe, getBillingPlans, ...)
  - Workspace + workflows: 13 renames (createWorkflow, ...)
  - Hub: 3 renames
  - Auth/users: 5 renames
  - Shared OSS surface (settings, queue, view, userdata): 12 renames
  - Misc cloud-only: 15 renames

Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate
(BE-1106), which compares handler type references against codegen
output from the combined spec.

* fix(openapi): resolve getHistory operationId collision

Spectral flagged: both /api/history (OSS local) and /api/history_v2
(cloud) had operationId 'getHistory' after the rename. Rename vendor's
/api/history to 'getPromptHistory' to disambiguate. Cloud's runtime
denies /api/history at the overlay level so combined codegen is
unaffected by this change.
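
A collision like the one above can be caught before linting with a plain text scan. The following is a minimal sketch, not the Spectral rule itself: the helper name `duplicate_operation_ids` and the spec fragment are hypothetical and only reproduce the two paths named in this commit.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper: read an OpenAPI YAML document on stdin and print
# every operationId that appears more than once.
duplicate_operation_ids() {
  # `|| true` keeps a spec with no operationIds from failing under pipefail.
  { grep -oE 'operationId:[[:space:]]*[A-Za-z0-9_]+' || true; } \
    | sed -E 's/^operationId:[[:space:]]*//' \
    | sort \
    | uniq -d
}

# Illustrative fragment (not the real openapi.yaml) with the collision
# described in the commit message.
spec_fragment='paths:
  /api/history:
    get:
      operationId: getHistory
  /api/history_v2:
    get:
      operationId: getHistory
  /api/queue:
    get:
      operationId: getQueue'

duplicates="$(printf '%s\n' "${spec_fragment}" | duplicate_operation_ids)"
echo "duplicate operationIds: ${duplicates:-none}"
```

The scan is purely textual, so it would also count an `operationId:` string that appears inside a description; a real check should parse the YAML.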

* openapi: add 41 cloud-runtime schemas to components.schemas (PR B of 3) (#14061)

* openapi: add 41 cloud-runtime schemas to components.schemas (cutover prep)

Adds schemas that exist in Comfy-Org/cloud's hand-written ingest spec
but not yet in this vendored OSS spec. All tagged x-runtime: [cloud]
per the field-drift convention and prefixed with [cloud-only] in the
description.

These schemas are referenced by cloud's Go handlers via the generated
ingest.<Schema> Go type names. Codegen from the vendored spec didn't
produce those types because the schemas weren't declared here. Adding
them unblocks the post-cutover combined-spec codegen.

Schemas added (alphabetical):
  AssetDownloadResponse, AssetMetadataResponse, BillingBalanceResponse,
  BillingPlansResponse, BillingStatusResponse, GetUserDataResponseFull,
  HistoryDetailEntry, HistoryDetailResponse, HistoryResponse,
  HubLabelInfo, HubProfileSummary, HubWorkflowListResponse,
  HubWorkflowStatus, HubWorkflowSummary, HubWorkflowTemplateEntry,
  JobStatusResponse, JobsListResponse, LabelRef, LogsResponse, Member,
  OAuthRegisterBadRequestResponse, PendingInvite, Plan, PlanAvailability,
  PlanAvailabilityReason, PlanSeatSummary, PreviewPlanInfo,
  PreviewSubscribeResponse, PublishedWorkflowDetail, SecretResponse,
  SubscriptionDuration, SubscriptionTier, UserDataResponseFull,
  ValidationError, ValidationResult, WorkflowForkedFrom, WorkflowResponse,
  WorkflowVersionContentResponse, WorkspaceAPIKeyInfo, WorkspaceSummary,
  WorkspaceWithRole

Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate
(BE-1106). Companion to PR #14060 (operationId renames).

* fix(openapi): add BindingErrorResponse schema

OAuthRegisterBadRequestResponse references BindingErrorResponse but
that schema wasn't in the original add. Adding it now as a cloud-only
schema matching the cloud runtime's binding-error shape (single
'message' string field).

* openapi: add missing 4xx/5xx response bodies for cloud-emitting endpoints (#14063)

Vendor declares shared endpoints (e.g. /api/queue, /api/settings,
/api/assets/*, /api/billing/*) with success responses but is missing
many of the 4xx/5xx error response bodies that Comfy-Org/cloud's
runtime actually emits. Cloud's Go handlers reference the generated
ingest.Op<StatusCode>JSONResponse types for these missing statuses,
which currently fail to resolve when codegen runs against the
vendored spec.

This PR adds 237 response entries across 117 operations, restoring
the documented error responses that cloud emits. Bodies are copied
verbatim from Comfy-Org/cloud's hand-written ingest spec
(services/ingest/openapi.yaml) and reference a new ErrorResponse
schema also added in this PR (matches cloud's {code, message} runtime
shape, tagged x-runtime: [cloud]).

ErrorResponse is intentionally separate from the existing CloudError
schema. CloudError's shape ({error}) describes one runtime; cloud
emits a different shape ({code, message}). Existing CloudError refs
in vendor are untouched; new cloud-emitting error references use
ErrorResponse.

Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate
(BE-1106). Companion to PR #14060 (operationId renames) and PR #14061
(cloud-only schema additions).
2026-05-22 16:15:18 -07:00
comfyanonymous
e75b739c1d
Delete the source branch after doing the backport. (#14062) 2026-05-22 15:47:03 -07:00
Matt Miller
112fcd5f3b
openapi: align response declarations with implementation (5 endpoints) (#14058)
* openapi: align response declarations with implementation (5 endpoints)

- POST /api/assets/download: replace 200 with 202 + tracking-task body
  (endpoint runs asynchronously and returns task_id/status/message).
- POST /api/assets/export: same 200 → 202 + tracking-task body.
- POST /api/assets/from-workflow: change 201 → 200 (handler responds 200,
  not 201; no Location header emitted).
- POST /api/feedback: change 200 → 201 (creates a feedback record).
- /api/jobs and /api/jobs/{job_id}: change timestamp fields from
  type: number to type: integer + format: int64. Values are Unix
  milliseconds — number causes oapi-codegen to emit float64, losing
  precision and producing the wrong Go type. Affected fields:
  create_time, update_time, execution_start_time, execution_end_time.

Verification: each change reflects what the endpoint observably returns;
no handler changes required. Backwards-compatible for existing clients
(integer is a subset of number; status code shifts within 2xx).

* openapi: align asset download/export 202 status enum with runtime + sibling schemas

CodeRabbit caught a vocabulary mismatch: the two new 202 response schemas
declared `[pending, running, completed, failed]` while the rest of the same
spec uses `[created, running, completed, failed]` for the identical task
lifecycle (download/export progress WebSocket events, /api/tasks, TaskEntry,
TaskResponse — 4 sites total). Cloud's runtime emits `created` on initial
creation (AssetDownloadResponseStatusCreated; task.Status sourced from the
DB enum whose initial value is Created). `pending` would have introduced a
fifth, contradictory vocabulary for the same lifecycle and pushed the spec
further from the implementation it is meant to align with.

Followup tracked separately: extract a shared TaskStatus enum so all five
sites move in lockstep instead of needing per-site edits.
2026-05-22 14:31:43 -07:00
Alexander Piskun
1579bbb52d
[Partner Nodes] add new Rodin2.5 nodes (#14051)
* [Partner Nodes] add new Rodin2.5 nodes

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fixed Quality Mesh Options

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix: remove non-supported "usdz"

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix: always pass seed to server

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix: set the default "material" value to "Shaded"

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-22 09:07:21 -07:00
Alexis Rolland
93888ae8e3
Move logic nodes into utils category (#14033)
2026-05-22 13:32:08 +08:00
Pauan
38ebc19037
Adding in And, Or, and Not nodes. (#14004) 2026-05-22 11:01:12 +08:00
comfyanonymous
9650570378
Update Discord invite link in README.md (#14045) 2026-05-21 19:52:38 -07:00
rattus
f48c32871b
fe: Consolidate warnings (#13970) 2026-05-22 10:18:13 +08:00
comfyanonymous
8edff549e3
Update backport workflow to use commit SHA input (#14043) 2026-05-21 18:22:47 -07:00
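
The workflow file in this comparison rejects anything that is not a full 40-character lowercase hex SHA. That check can be tried on its own; the sketch below reuses the workflow's regex, while the helper name `is_full_sha` and the sample values are made up for illustration.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Hypothetical helper wrapping the same regex the workflow uses.
is_full_sha() {
  [[ "$1" =~ ^[0-9a-f]{40}$ ]]
}

# Illustrative inputs only: a short SHA, an uppercase 40-char value,
# and a well-formed 40-char lowercase value.
for candidate in \
  "0b04660ba3" \
  "0B04660BA30B04660BA30B04660BA30B04660BA3" \
  "0b04660ba30b04660ba30b04660ba30b04660ba3"; do
  if is_full_sha "${candidate}"; then
    echo "accept ${candidate}"
  else
    echo "reject ${candidate}"
  fi
done
```

Only the last candidate is accepted. The regex checks form only; the workflow separately verifies that the commit exists and is the tip of exactly one branch.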
comfyanonymous
8fecef0686
Add validation for source branch in backport workflow (#14042) 2026-05-21 16:39:19 -07:00
Jedrzej Kosinski
5d681a5420
Fix SIGPIPE false negative in backport release validation (#14041) 2026-05-21 16:29:08 -07:00
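
The comments in the workflow file describe the failure mode: under `set -o pipefail`, `grep -q` exits on its first match and the producer that is still writing is terminated, so the pipeline reports failure even though the match succeeded. The sketch below reproduces that in isolation. It assumes an endless `yes` stream as a stand-in for `git rev-list`, and the function names are hypothetical.

```bash
#!/usr/bin/env bash
set -uo pipefail

# Stand-in for a producer that is still streaming when grep finds its match.
produce() { yes needle; }

# Fragile form: grep -q exits on the first match, the producer fails on its
# next write, and pipefail surfaces that failure as the pipeline status.
piped_check() { produce | grep -Fxq needle; }

# Robust form: capture the output first, then grep a here-string, so no
# process is left writing into a closed pipe.
captured_check() {
  local out
  out="$(printf '%s\n' alpha needle omega)"
  grep -Fxq needle <<< "${out}"
}

piped_check
piped_status=$?
captured_check
captured_status=$?

echo "piped=${piped_status} captured=${captured_status}"
```

On a typical Linux runner the piped status is 141 (128 plus SIGPIPE's signal number 13), while the captured form returns 0 for the same match.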
comfyanonymous
32e58393b8
Add backport release workflow. (#14038)
2026-05-21 14:49:55 -07:00
Alexander Piskun
b293f8cefd
[Partner Nodes] add widget for automatic upscaling for the ByteDance2Reference node (#14032)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-21 11:58:03 -07:00
Daxiong (Lin)
2ca1480f91
chore: update workflow templates to v0.9.82 (#14034) 2026-05-21 11:48:20 -07:00
Alexander Piskun
6ecf5eca7a
[Partner Nodes] add OpenRouter LLM node (#14007)
* [Partner Nodes] add reasoning widget to Anthropic node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] add new OpenRouterLLM node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] fix passing images to Grok LLM

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-21 11:36:11 -07:00
rattus
03e511862e
Fix reshaping lora application (#14031)
* ModelPatcherDynamic: purge stale vbar allocs on force cast

* ModelPatcherDynamic: restore backups before load

If doing a clean reload, mutative changes (lora application) could be
applied on top of the already-loaded weight. Restore from backup
unconditionally so that the new load is clean.
2026-05-21 09:47:16 -07:00
Edoardo Carmignani
aab41a9ddb
fix(lanczos): correct dimension transposition for single-channel tensors (#12679) 2026-05-21 23:47:20 +08:00
Alexis Rolland
4259a0c7c3
Update MoGe nodes display names, search aliases and descriptions (#14030)
2026-05-21 16:50:09 +08:00
Alexis Rolland
af3d9b60af
chore: Dataset nodes clean-up (CORE-237) (#14002) 2026-05-21 15:14:16 +08:00
Alexis Rolland
7b7c5fed7c
Update MediaPipe nodes to standardize with existing code base (CORE-242) (#14025) 2026-05-21 14:39:30 +08:00
Matt Miller
1668aaf037
openapi: remove cloud-only job_ids query param from GET /api/assets (#14016)
The job_ids query parameter on GET /api/assets is tagged x-runtime:
[cloud] and only exists for cloud's variant of this endpoint. Cloud
removed all consumers and the cloud-side handler/codegen/tests in
Comfy-Org/cloud#3778. With cloud no longer accepting this parameter,
the [cloud-only] documentation here is wrong — drop it so the daily
sync to cloud/services/ingest/vendor/openapi.yaml propagates the
removal.
2026-05-20 21:32:08 -07:00
Matt Miller
ea174d3f12
fix(openapi): correct POST /api/assets/import to importPublishedAssets (#14027)
The operation at POST /api/assets/import was defined as `importAssets`
with a URL-list body shape, but no runtime actually serves that
operation at this path. The cloud runtime serves a different operation
here — `importPublishedAssets` — which imports published-workflow
assets into the caller's library by ID, not by URL.

Cloud's URL-based asset ingestion lives at separate paths
(POST /assets/download + GET /assets/remote-metadata) tracked
elsewhere; nothing in this PR affects that work.

Changes:

- Replace the operation at POST /api/assets/import with
  `importPublishedAssets`, taking ImportPublishedAssetsRequest
  (published_asset_ids + optional share_id) and returning
  ImportPublishedAssetsResponse (list of AssetInfo).
- Remove the unused AssetImportRequest component schema (no other
  references in the spec).
- Operation and schemas tagged x-runtime: [cloud] with [cloud-only]
  description prefix, matching the existing convention for
  cloud-runtime-only operations elsewhere in the spec.

Spectral lint passes (0 errors); the two hint-level findings on
the spec are pre-existing and unrelated.

No FE consumer references AssetImportRequest today; this is a pure
spec correction to match what the cloud runtime actually serves.
2026-05-20 21:28:16 -07:00
Matt Miller
9f9b32ed97
feat: add OAuth 2.1 + RFC 7591 DCR endpoints to openapi.yaml (#14026)
Add the OAuth 2.1 authorization flow and RFC 7591 Dynamic Client
Registration endpoints to the shared spec, alongside the existing
auth-tagged operations (/api/auth/session, /api/auth/token,
/.well-known/jwks.json). All tagged x-runtime: [cloud] with a
[cloud-only] description prefix, following the established
convention for cloud-runtime-only operations.

Endpoints:

- GET  /.well-known/oauth-authorization-server  (RFC 8414 metadata)
- GET  /.well-known/oauth-protected-resource    (RFC 9728 metadata)
- GET  /oauth/authorize                         (consent challenge)
- POST /oauth/authorize                         (consent submission)
- POST /oauth/token                             (RFC 6749 §3.2)
- POST /oauth/register                          (RFC 7591 §3.1 DCR)

Component schemas added:

- OAuthAuthorizationServerMetadata
- OAuthProtectedResourceMetadata
- OAuthConsentChallenge, OAuthConsentChallengeWorkspace
- OAuthAuthorizeRedirectResponse
- OAuthTokenResponse, OAuthTokenError
- OAuthRegisterRequest, OAuthRegisterResponse, OAuthRegisterError

These endpoints are implemented in the cloud runtime today and
are called by browser frontends rendering the consent UI and by
MCP-spec-compliant clients (Claude Desktop, Cursor, etc.) doing
auto-discovery + self-registration. Documenting them in the
shared spec lets the cloud frontend generate types directly from
this spec instead of maintaining a parallel definition.

Spectral lints clean (0 errors). The hint-level findings on
OAuthTokenError / OAuthRegisterError ("standard error schema")
match the same hint on CloudError — these are protocol-specific
RFC-shaped errors, not generic application errors.
2026-05-20 21:22:12 -07:00
comfyanonymous
95fdc6cf91
Repo security stuff. (#14019) 2026-05-20 17:17:55 -07:00
rattus
5aa5ccc9e0
Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (#13802)
* model_management: disable non-dynamic smart memory

Disable smart memory outright for non-dynamic models.

This is a minor step towards deprecation of --disable-dynamic-vram
and the legacy ModelPatcher.

This is needed for estimate-free model development, where new models
can opt out of supplying a memory estimate and not have to worry
about hard VRAM allocations due to legacy non-dynamic model patchers.

This is also a general stability increase for the many stray use cases
where estimates may still be off; going forward, we are not going to
maintain such estimates accurately.

* pinned_memory: implement with aimdo growable buffer

Use a single growable buffer so we can do threaded pre-warming on
pinned memory.

* mm: use aimdo to do transfer from disk to pin

Aimdo implements a faster threaded loader.

* Add stream host pin buffer for AIMDO casts

Introduce per-offload-stream HostBuffer reuse for pinned staging and
include it in cast buffer reset synchronization.

Defer actual casts that go via this pin path to a separate pass
so that the buffer can be allocated monolithically (to avoid
cudaHostRegister thrash).

* remove old pin path

* Implement JIT pinned memory pressure

Replace the predictive pin pressure mechanism with JIT PIN memory
pressure.

* LowVRAMPatch: change to two-phase visit

* lora: re-implement as inplace swiss-army-knife operation

* prepare for multiple pin sets

* implement pinned loras

* requirements: comfy-aimdo 0.4.0

* ops: remove unused arg

This was defeatured in aimdo iteration

* ops: sync the CPU with only the offload stream activity

This was syncing with the offload stream which itself is synced with the
compute stream, so this was syncing CPU with compute transitively. Define
the event to sync it more gently.

* pins: implement freeing intermediate for pinned memory

Pinning is more important than inactive intermediates and the stream
pin buffer is more important than even active intermediates.

* execution: implement pin eviction on RAM pressure

Add back proper pin freeing on RAM pressure.

* implement pin registration swaps

Uncap the Windows pins from 50% by extending the pool and add a pressure
mechanism to move the pin reservations on demand.

This unfortunately implies a GPU sync to do the freeing, so significant
hysteresis needs to be added to consolidate these pressure events.

* cli_args/execution: Implement lower background cache-ram threshold

Limit the amount of RAM background intermediates can use, so that
switching workflows doesn't degrade performance too much.

* make default

* bump aimdo

* model-patcher: force-cast tiny weights

Flux 2 stalls badly because its mix of tiny and giant weights creates
lopsided stream buffer rotations.

* ops: refactor in prep for chunking

* mm: delegate pin-on-the-way to aimdo

Aimdo is able to chunk and slice this on the way for better CPU->GPU
overlap. The main advantage is the ability to shorten the bus contention
window between previous weight transfer and the next weights vbar
fault.

* bump aimdo

* pinning updates

* specify hostbuf max allocation size

There are signs of virtual memory exhaustion on some Linux systems when
throwing 128GB at every little piece. Pass the actual size to spare
aimdo from over-estimating.

* tests: update execution tests for caching

The default caching changed to ram-cache so update these tests
accordingly.

Remove the LRU 0 test as this also falls through to RAM cache.
2026-05-20 17:03:58 -07:00
Jukka Seppänen
4d6a058bf1
feat: MediaPipe face detection (CORE-235) (#14009)
* Initial mediapipe face detection support

* Update face_geometry.py

* Account for diff sized batch input

* Model folder placeholder
2026-05-20 16:07:48 -07:00
238 changed files with 39598 additions and 4283 deletions

519
.github/workflows/backport_release.yaml vendored Normal file

@@ -0,0 +1,519 @@
name: Backport Release

on:
  workflow_dispatch:
    inputs:
      commit:
        description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.'
        required: true
        type: string

permissions:
  contents: read
  pull-requests: read
  checks: read

jobs:
  backport-release:
    name: Create backport release
    runs-on: ubuntu-latest
    environment: backport release
    steps:
      - name: Generate GitHub App token
        id: app-token
        uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1
        with:
          app-id: ${{ secrets.FEN_RELEASE_APP_ID }}
          private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }}

      - name: Checkout repository
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
        with:
          token: ${{ steps.app-token.outputs.token }}
          fetch-depth: 0
          fetch-tags: true

      - name: Configure git
        run: |
          git config user.name "fen-release[bot]"
          git config user.email "fen-release[bot]@users.noreply.github.com"

      - name: Resolve source branch from commit SHA
        id: resolve
        env:
          SOURCE_COMMIT: ${{ inputs.commit }}
          DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
        run: |
          set -euo pipefail
          # Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous
          # and we will be comparing this value against API responses (PR head
          # SHA, ref tips) that always return the full form.
          if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then
            echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA."
            exit 1
          fi
          # Fetch all remote branches so we can search for which one(s) point
          # at this SHA. `actions/checkout` with fetch-depth: 0 fetches full
          # history of the checked-out ref but does not necessarily populate
          # every refs/remotes/origin/*, so do it explicitly.
          git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*'
          # Verify the commit actually exists in this repo's object DB.
          if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then
            echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository."
            exit 1
          fi
          # Find every remote branch whose tip == SOURCE_COMMIT. Exactly one
          # branch must point at it. If zero, the commit isn't anyone's tip
          # (likely stale, force-pushed past, or never the PR head). If more
          # than one, the (branch -> SHA) mapping is ambiguous and we refuse
          # to guess — the operator must give us a unique branch to release.
          mapfile -t matching_branches < <(
            git for-each-ref \
              --format='%(refname:strip=3)' \
              --points-at="${SOURCE_COMMIT}" \
              refs/remotes/origin/ \
              | grep -vx 'HEAD' || true
          )
          if [[ "${#matching_branches[@]}" -eq 0 ]]; then
            echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip."
            echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch."
            exit 1
          fi
          if [[ "${#matching_branches[@]}" -gt 1 ]]; then
            echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:"
            for b in "${matching_branches[@]}"; do
              echo "::error:: - ${b}"
            done
            echo "::error::Refusing to proceed with an ambiguous source branch."
            exit 1
          fi
          source_branch="${matching_branches[0]}"
          if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then
            echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')."
            exit 1
          fi
          echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'."
          echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT"

      - name: Determine latest stable release
        id: latest
        env:
          GH_TOKEN: ${{ steps.app-token.outputs.token }}
        run: |
          set -euo pipefail
          # List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric
          # comparison of each component. We DO NOT use `sort -V` because it treats
          # v0.19.99 as higher than v0.20.1.
          latest_tag="$(
            git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \
              | grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \
              | awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \
              | sort -k1,1n -k2,2n -k3,3n \
              | tail -n1 \
              | awk '{print $4}'
          )"
          if [[ -z "${latest_tag}" ]]; then
            echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found."
            exit 1
          fi
          # Parse components
          ver="${latest_tag#v}"
          major="${ver%%.*}"
          rest="${ver#*.}"
          minor="${rest%%.*}"
          patch="${rest#*.}"
          new_patch=$((patch + 1))
          new_version="v${major}.${minor}.${new_patch}"
          release_branch="release/v${major}.${minor}"
          latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")"
          echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT"
          echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT"
          echo "major=${major}" >> "$GITHUB_OUTPUT"
          echo "minor=${minor}" >> "$GITHUB_OUTPUT"
          echo "patch=${patch}" >> "$GITHUB_OUTPUT"
          echo "new_version=${new_version}" >> "$GITHUB_OUTPUT"
          echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT"
          echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT"
          echo "Latest stable release: ${latest_tag} (${latest_sha})"
          echo "New version will be: ${new_version}"
          echo "Release branch: ${release_branch}"

      - name: Validate source branch is cut directly from the latest stable release
        env:
          SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
          SOURCE_COMMIT: ${{ inputs.commit }}
          LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
          LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
        run: |
          set -euo pipefail
          # Use the user-provided SHA directly rather than re-resolving the branch
          # tip — the resolve step already proved the branch tip equals SOURCE_COMMIT,
          # and pinning to the SHA here makes the rest of the job TOCTOU-safe against
          # someone pushing to the branch mid-run.
          source_sha="${SOURCE_COMMIT}"
          # Walking first-parent from the source tip must reach LATEST_TAG_SHA.
          # We capture rev-list into a variable and grep against a here-string
          # rather than piping `rev-list | grep -q`: under `set -o pipefail`,
          # `grep -q` would exit on first match and SIGPIPE the still-streaming
          # `rev-list`, propagating exit 141 as a spurious "not found".
          first_parent_chain="$(git rev-list --first-parent "${source_sha}")"
          if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then
            echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'."
            echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}."
            exit 1
          fi
          # Additionally, every commit added on top of the tag (the set we are
          # about to publish) must itself be a descendant of the tag along
          # first-parent — i.e. no sibling commits from master sneak in via a
          # non-first-parent path. Enforce by requiring that the symmetric
          # difference is empty in one direction: commits in source that are
          # NOT first-parent-reachable from source starting at the tag.
          # We do this by intersecting:
          # A = commits reachable from source but not from tag (full DAG)
          # B = commits on the first-parent chain from source down to tag
          # and requiring A == B.
          all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)"
          first_parent_added="$(
            git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort
          )"
          if [[ "${all_added}" != "${first_parent_added}" ]]; then
            echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'."
            echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master."
            echo "Commits reachable but not on first-parent chain:"
            comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \
              | while read -r sha; do
                  echo " $(git log -1 --format='%h %s' "${sha}")"
                done
            exit 1
          fi
          added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)"
          echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top."

      - name: Validate PR exists, is open, named correctly, has latest commit, and checks pass
        env:
          GH_TOKEN: ${{ steps.app-token.outputs.token }}
          SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
          SOURCE_COMMIT: ${{ inputs.commit }}
          NEW_VERSION: ${{ steps.latest.outputs.new_version }}
          REPO: ${{ github.repository }}
        run: |
          set -euo pipefail
          expected_title="ComfyUI backport release ${NEW_VERSION}"
          # Find open PRs from this branch into master. The --state open filter
          # is load-bearing: a closed/merged PR with passing checks must not be
          # accepted as authorization for a new release.
          pr_json="$(
            gh pr list \
              --repo "${REPO}" \
              --state open \
              --head "${SOURCE_BRANCH}" \
              --base master \
              --json number,title,headRefOid,state \
              --limit 10
          )"
          pr_count="$(echo "${pr_json}" | jq 'length')"
          if [[ "${pr_count}" -eq 0 ]]; then
            echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open."
            exit 1
          fi
          # Pick the PR matching the expected title
          pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
            map(select(.title == $t)) | .[0].number // empty
          ')"
          pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" '
            map(select(.title == $t)) | .[0].headRefOid // empty
          ')"
          if [[ -z "${pr_number}" ]]; then
            echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'."
            echo "Found PRs:"
            echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"'
            exit 1
          fi
          # The PR's current head commit must equal the SHA the operator gave us.
          # This is what closes the door on releasing stale code: if anyone has
          # pushed to the branch since the operator validated tests passed, the
          # PR head will have advanced past SOURCE_COMMIT and we abort. (The
          # resolve step already proved the branch tip == SOURCE_COMMIT; this
# ties that same SHA to the PR that authorizes the release.)
if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then
echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}."
echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks."
exit 1
fi
echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)."
# Verify all check runs on the head commit have completed successfully.
# A check is considered passing if conclusion is success, neutral, or skipped.
checks_json="$(
gh api \
--paginate \
"repos/${REPO}/commits/${pr_head_sha}/check-runs" \
--jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}'
)"
if [[ -z "${checks_json}" ]]; then
echo "::error::No check runs found on PR head commit ${pr_head_sha}."
exit 1
fi
echo "Check runs on ${pr_head_sha}:"
echo "${checks_json}" | jq -s '.'
failing="$(echo "${checks_json}" | jq -s '
map(select(
.status != "completed"
or (.conclusion as $c
| ["success","neutral","skipped"]
| index($c) | not)
))
')"
failing_count="$(echo "${failing}" | jq 'length')"
if [[ "${failing_count}" -gt 0 ]]; then
echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:"
echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"'
exit 1
fi
echo "All checks have passed on ${pr_head_sha}."
- name: Prepare release branch
id: prepare
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }}
PATCH: ${{ steps.latest.outputs.patch }}
run: |
set -euo pipefail
# Try to fetch the release branch. If patch == 0, it shouldn't exist yet
# and we'll create it from the latest stable tag. If patch > 0, it must
# already exist and its tip must equal the latest stable tag commit (i.e.
# the previous patch release).
if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then
echo "Release branch '${RELEASE_BRANCH}' already exists on origin."
git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}"
git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}"
current_tip="$(git rev-parse HEAD)"
if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})."
echo "::error::Refusing to release on top of a divergent branch."
exit 1
fi
echo "branch_existed=true" >> "$GITHUB_OUTPUT"
else
if [[ "${PATCH}" != "0" ]]; then
echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent."
exit 1
fi
echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}."
git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}"
echo "branch_existed=false" >> "$GITHUB_OUTPUT"
fi
- name: Fast-forward merge source branch into release branch
env:
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
run: |
set -euo pipefail
# --ff-only guarantees no merge commit is created. If a fast-forward is
# not possible (i.e. the release branch has commits the source branch
# doesn't), the merge will fail and we abort. Because we already validated
# that the source branch is rooted on the latest stable tag, and the
# release branch tip equals that same tag, this fast-forward should
# always succeed for a well-formed backport branch.
#
# We merge the operator-provided SHA, not the branch ref, so a push to
# the branch in the window between resolve and now cannot smuggle new
# commits into the release.
if ! git merge --ff-only "${SOURCE_COMMIT}"; then
echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting."
exit 1
fi
echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')."
- name: Bump version files
env:
NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }}
run: |
set -euo pipefail
if [[ ! -f comfyui_version.py ]]; then
echo "::error::comfyui_version.py not found in repo root."
exit 1
fi
if [[ ! -f pyproject.toml ]]; then
echo "::error::pyproject.toml not found in repo root."
exit 1
fi
# Replace the version string in comfyui_version.py.
# Expected format: __version__ = "X.Y.Z"
python3 - "$NEW_VERSION_NO_V" <<'PY'
import re, sys, pathlib
new = sys.argv[1]
p = pathlib.Path("comfyui_version.py")
src = p.read_text()
new_src, n = re.subn(
r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find __version__ assignment in comfyui_version.py")
p.write_text(new_src)
p = pathlib.Path("pyproject.toml")
src = p.read_text()
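# Replace the first line that starts with `version = "..."`. This is
# expected to be the one under [project] or [tool.poetry], but the regex
# itself does not check which section the line belongs to.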
new_src, n = re.subn(
r'(?m)^(version\s*=\s*")[^"]+(")',
lambda m: f'{m.group(1)}{new}{m.group(2)}',
src,
count=1,
)
if n != 1:
sys.exit("Could not find version assignment in pyproject.toml")
p.write_text(new_src)
PY
echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml."
git --no-pager diff -- comfyui_version.py pyproject.toml
- name: Commit version bump and tag release
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
git add comfyui_version.py pyproject.toml
git commit -m "ComfyUI ${NEW_VERSION}"
if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then
echo "::error::Tag ${NEW_VERSION} already exists locally."
exit 1
fi
git tag "${NEW_VERSION}"
- name: Verify tag does not already exist on origin
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then
echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting."
exit 1
fi
- name: Push release branch and tag
env:
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
run: |
set -euo pipefail
# Push the branch first, then the tag. This is not atomic, but the order
# is deliberate: under `set -e`, a failed branch push aborts the step
# before the tag is ever published.
git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}"
git push origin "refs/tags/${NEW_VERSION}"
echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}."
- name: Delete remote source branch
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
REPO: ${{ github.repository }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
DEFAULT_BRANCH: ${{ github.event.repository.default_branch }}
run: |
set -euo pipefail
# Belt-and-braces: the resolve step already refuses the default branch,
# but never delete the default or the release branch under any
# circumstances.
if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then
echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)."
exit 1
fi
# Delete the source branch on origin, but only if its tip is still the
# SHA we released from. If someone pushed new commits to it after we
# resolved it, leave it alone — those commits would be silently lost.
current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')"
if [[ -z "${current_tip}" ]]; then
echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete."
exit 0
fi
if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then
echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place."
exit 0
fi
git push origin --delete "refs/heads/${SOURCE_BRANCH}"
echo "Deleted remote branch '${SOURCE_BRANCH}'."
- name: Summary
if: always()
env:
NEW_VERSION: ${{ steps.latest.outputs.new_version }}
RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }}
LATEST_TAG: ${{ steps.latest.outputs.latest_tag }}
SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }}
SOURCE_COMMIT: ${{ inputs.commit }}
run: |
# SOURCE_BRANCH is empty if the resolve step never produced an output
# (e.g. the workflow failed in or before that step). Show a placeholder
# in that case so the summary table still renders cleanly.
source_branch_display="${SOURCE_BRANCH:-(unresolved)}"
{
echo "## Backport release"
echo ""
echo "| Field | Value |"
echo "|---|---|"
echo "| Source commit | \`${SOURCE_COMMIT}\` |"
echo "| Source branch | \`${source_branch_display}\` |"
echo "| Previous stable | \`${LATEST_TAG}\` |"
echo "| New version | \`${NEW_VERSION}\` |"
echo "| Release branch | \`${RELEASE_BRANCH}\` |"
} >> "$GITHUB_STEP_SUMMARY"
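
The first-parent validation earlier in this workflow compares two commit sets: everything added on top of the tag, and only what sits on the first-parent chain. A minimal Python sketch of that comparison on a toy history follows. It assumes a hypothetical `parents` mapping (commit id to ordered parent list) in place of real `git rev-list` output, and the helper names are illustrative, not part of the workflow.

```python
def reachable(parents, tip):
    """All commits reachable from tip by following every parent."""
    seen, stack = set(), [tip]
    while stack:
        commit = stack.pop()
        if commit not in seen:
            seen.add(commit)
            stack.extend(parents.get(commit, []))
    return seen


def first_parent_chain(parents, tip):
    """Commits reached from tip by following only the first parent."""
    chain, commit = [], tip
    while commit is not None:
        chain.append(commit)
        ps = parents.get(commit, [])
        commit = ps[0] if ps else None
    return chain


def stray_commits(parents, tag, source):
    """Commits added on top of tag that are not on source's first-parent chain."""
    base = reachable(parents, tag)
    # Roughly: git rev-list tag..source
    all_added = reachable(parents, source) - base
    # Roughly: git rev-list --first-parent tag..source
    first_parent_added = set(first_parent_chain(parents, source)) - base
    return all_added - first_parent_added


# Toy histories (hypothetical): tag T, backport commits B1/B2, master commit M1.
linear = {"T": [], "B1": ["T"], "B2": ["B1"]}
merged = {"T": [], "M1": ["T"], "B1": ["T"], "B2": ["B1", "M1"]}  # B2 merges master

print(stray_commits(linear, "T", "B2"))  # set()
print(stray_commits(merged, "T", "B2"))  # {'M1'}
```

An empty result corresponds to the workflow's `A == B` case. A non-empty result is the set the workflow prints under "Commits reachable but not on first-parent chain".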

@ -0,0 +1,24 @@
name: Detect Unreviewed Merge
# SOC 2 compliance — reusable workflow lives in Comfy-Org/github-workflows,
# tracking issues are filed in Comfy-Org/unreviewed-merges.
on:
push:
branches: [master]
concurrency:
group: detect-unreviewed-merge-${{ github.sha }}
cancel-in-progress: false
permissions:
contents: read
pull-requests: read
jobs:
detect:
uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1
with:
approval-mode: latest-per-reviewer
secrets:
UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }}

@ -1,2 +1,5 @@
# Admins
* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai * @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
/CODEOWNERS @comfyanonymous
/.ci/ @comfyanonymous
/.github/ @comfyanonymous

@ -20,7 +20,7 @@
[website-url]: https://www.comfy.org/ [website-url]: https://www.comfy.org/
<!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 --> <!-- Workaround to display total user from https://github.com/badges/shields/issues/4500#issuecomment-2060079995 -->
[discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total [discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total
[discord-url]: https://www.comfy.org/discord [discord-url]: https://discord.com/invite/comfyorg
[twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI [twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI
[twitter-url]: https://x.com/ComfyUI [twitter-url]: https://x.com/ComfyUI
@ -433,7 +433,7 @@ See also: [https://www.comfy.org/](https://www.comfy.org/)
## Frontend Development ## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory. As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). The compiled JS files (from TS/Vue) are published to [pypi](https://pypi.org/project/comfyui-frontend-package) and installed as a dependency in ComfyUI.
### Reporting Issues and Requesting Features ### Reporting Issues and Requesting Features

@ -160,10 +160,12 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
preview_url = None preview_url = None
else: else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata) preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
asset_content_hash = result.asset.hash if result.asset else None
return schemas_out.Asset( return schemas_out.Asset(
id=result.ref.id, id=result.ref.id,
name=result.ref.name, name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None, hash=asset_content_hash,
asset_hash=asset_content_hash,
size=int(result.asset.size_bytes) if result.asset else None, size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None, mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags, tags=result.tags,

@ -10,6 +10,7 @@ class Asset(BaseModel):
id: str id: str
name: str name: str
hash: str | None = None
asset_hash: str | None = None asset_hash: str | None = None
size: int | None = None size: int | None = None
mime_type: str | None = None mime_type: str | None = None

@ -4,7 +4,6 @@ Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only) Tier 2: Safetensors header metadata (fast JSON read only)
""" """
from __future__ import annotations
import json import json
import logging import logging

@ -1,5 +1,3 @@
from __future__ import annotations
import os import os
import folder_paths import folder_paths
import glob import glob

@ -1,4 +1,3 @@
from __future__ import annotations
import argparse import argparse
import logging import logging
import os import os
@ -62,6 +61,8 @@ def get_comfy_package_versions():
def check_comfy_packages_versions(): def check_comfy_packages_versions():
"""Warn for every comfy* package whose installed version is below requirements.txt.""" """Warn for every comfy* package whose installed version is below requirements.txt."""
from packaging.version import InvalidVersion, parse as parse_pep440 from packaging.version import InvalidVersion, parse as parse_pep440
outdated_packages = []
for pkg in get_comfy_package_versions(): for pkg in get_comfy_package_versions():
installed_str = pkg["installed"] installed_str = pkg["installed"]
required_str = pkg["required"] required_str = pkg["required"]
@ -73,19 +74,26 @@ def check_comfy_packages_versions():
logging.error(f"Failed to check {pkg['name']} version: {e}") logging.error(f"Failed to check {pkg['name']} version: {e}")
continue continue
if outdated: if outdated:
outdated_packages.append((pkg["name"], installed_str, required_str))
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
if outdated_packages:
package_warnings = "\n".join(
f"Installed {name} version {installed} is lower than the recommended version {required}."
for name, installed, required in outdated_packages
)
app.logger.log_startup_warning( app.logger.log_startup_warning(
f""" f"""
________________________________________________________________________ ________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING WARNING
Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}. {package_warnings}
{get_missing_requirements_message()} {get_missing_requirements_message()}
________________________________________________________________________ ________________________________________________________________________
""".strip() """.strip()
) )
else:
logging.info("{} version: {}".format(pkg["name"], installed_str))
REQUEST_TIMEOUT = 10 # seconds REQUEST_TIMEOUT = 10 # seconds
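
The hunk above collects outdated packages first and then emits a single combined warning, rather than one banner per package. A minimal sketch of that pattern follows. It assumes a simplified dotted-integer comparison in place of `packaging.version.parse`, and both `build_outdated_warning` and the package data are hypothetical, for illustration only.

```python
def parse_simple(version):
    """Simplified stand-in for PEP 440 parsing: dotted integers only."""
    return tuple(int(part) for part in version.split("."))


def build_outdated_warning(packages):
    """Return one combined warning string, or None if nothing is outdated."""
    outdated = [
        (pkg["name"], pkg["installed"], pkg["required"])
        for pkg in packages
        if parse_simple(pkg["installed"]) < parse_simple(pkg["required"])
    ]
    if not outdated:
        return None
    # One line per outdated package, joined into a single message body.
    return "\n".join(
        f"Installed {name} version {installed} is lower than the recommended version {required}."
        for name, installed, required in outdated
    )


# Hypothetical package data for illustration only.
packages = [
    {"name": "pkg-a", "installed": "1.2.0", "required": "1.3.0"},
    {"name": "pkg-b", "installed": "0.4.1", "required": "0.4.1"},
    {"name": "pkg-c", "installed": "0.9.9", "required": "0.10.0"},
]
print(build_outdated_warning(packages))
```

Building the message separately from logging it keeps the banner text in one place, so the caller decides once whether to log anything at all.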

@ -5,6 +5,40 @@ import logging
import sys import sys
import threading import threading
ANSI_NAMED_COLORS = {
'black': '\033[30m',
'red': '\033[31m',
'green': '\033[32m',
'yellow': '\033[33m',
'blue': '\033[34m',
'magenta': '\033[35m',
'cyan': '\033[36m',
'white': '\033[37m',
}
ANSI_LEVEL_COLORS = {
'DEBUG': ANSI_NAMED_COLORS['cyan'],
'INFO': ANSI_NAMED_COLORS['green'],
'WARNING': ANSI_NAMED_COLORS['yellow'],
'ERROR': ANSI_NAMED_COLORS['red'],
'CRITICAL': ANSI_NAMED_COLORS['magenta'],
}
ANSI_RESET = '\033[0m'
ANSI_BOLD = '\033[1m'
class ColoredFormatter(logging.Formatter):
def format(self, record):
color = ANSI_LEVEL_COLORS.get(record.levelname, '')
bold = ANSI_BOLD if record.levelno >= logging.WARNING else ''
level_tag = f"{bold}{color}[{record.levelname}]{ANSI_RESET} "
message = super().format(record)
line_color = ANSI_NAMED_COLORS.get(getattr(record, 'color', ''), '')
if line_color:
return f"{level_tag}{line_color}{message}{ANSI_RESET}"
return level_tag + message
logs = None logs = None
stdout_interceptor = None stdout_interceptor = None
stderr_interceptor = None stderr_interceptor = None
@ -68,8 +102,10 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(log_level) logger.setLevel(log_level)
formatter = ColoredFormatter("%(message)s")
stream_handler = logging.StreamHandler() stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(message)s")) stream_handler.setFormatter(formatter)
if use_stdout: if use_stdout:
# Only errors and critical to stderr # Only errors and critical to stderr
@ -77,7 +113,7 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
# Lesser to stdout # Lesser to stdout
stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(logging.Formatter("%(message)s")) stdout_handler.setFormatter(formatter)
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR) stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
logger.addHandler(stdout_handler) logger.addHandler(stdout_handler)
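
The formatter added in this file reads an optional `color` attribute from each log record. The standard `extra` argument of the logging calls is one way to set such an attribute. The sketch below re-declares the formatter in trimmed form so it is self-contained. The logger name `demo` and the in-memory stream are illustrative assumptions, not part of the change.

```python
import io
import logging

ANSI_RESET = "\033[0m"
ANSI_BOLD = "\033[1m"
ANSI_NAMED_COLORS = {"cyan": "\033[36m", "green": "\033[32m", "yellow": "\033[33m"}
ANSI_LEVEL_COLORS = {
    "INFO": ANSI_NAMED_COLORS["green"],
    "WARNING": ANSI_NAMED_COLORS["yellow"],
}


class ColoredFormatter(logging.Formatter):
    def format(self, record):
        color = ANSI_LEVEL_COLORS.get(record.levelname, "")
        bold = ANSI_BOLD if record.levelno >= logging.WARNING else ""
        level_tag = f"{bold}{color}[{record.levelname}]{ANSI_RESET} "
        message = super().format(record)
        # extra={"color": ...} on a logging call becomes an attribute of the record.
        line_color = ANSI_NAMED_COLORS.get(getattr(record, "color", ""), "")
        if line_color:
            return f"{level_tag}{line_color}{message}{ANSI_RESET}"
        return level_tag + message


stream = io.StringIO()  # stand-in for the terminal
handler = logging.StreamHandler(stream)
handler.setFormatter(ColoredFormatter("%(message)s"))
demo_logger = logging.getLogger("demo")
demo_logger.setLevel(logging.INFO)
demo_logger.propagate = False  # keep the demo output out of the root logger
demo_logger.addHandler(handler)

demo_logger.info("plain line")
demo_logger.info("tinted line", extra={"color": "cyan"})
demo_logger.warning("bold level tag")

lines = stream.getvalue().splitlines()
print([repr(line) for line in lines])
```

Records without a `color` attribute, or with a name missing from the color table, fall back to the level tag followed by the uncolored message.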

@ -1,5 +1,3 @@
from __future__ import annotations
import os import os
import base64 import base64
import json import json

@ -1,4 +1,3 @@
from __future__ import annotations
import json import json
import os import os
import re import re

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

@ -1553,7 +1553,7 @@
"VHS_MetadataImage": true, "VHS_MetadataImage": true,
"VHS_KeepIntermediate": true "VHS_KeepIntermediate": true
}, },
"category": "Image generation and editing/Canny to image", "category": "Image generation and editing/Conditioned",
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning." "description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
} }
] ]

@ -3600,7 +3600,7 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Video generation and editing/Canny to video", "category": "Video generation and editing/Conditioned",
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio." "description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
} }
] ]

@ -1401,7 +1401,7 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Image generation and editing/ControlNet", "category": "Image generation and editing/Conditioned",
"description": "Generates images from a text prompt and ControlNet conditioning (e.g. depth, canny) using Z-Image-Turbo." "description": "Generates images from a text prompt and ControlNet conditioning (e.g. depth, canny) using Z-Image-Turbo."
} }
] ]

@ -1579,7 +1579,7 @@
"VHS_MetadataImage": true, "VHS_MetadataImage": true,
"VHS_KeepIntermediate": true "VHS_KeepIntermediate": true
}, },
"category": "Image generation and editing/Depth to image", "category": "Image generation and editing/Conditioned",
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning." "description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
}, },
{ {

@ -4233,7 +4233,7 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Video generation and editing/Depth to video", "category": "Video generation and editing/Conditioned",
"description": "Generates depth-controlled video with LTX-2: motion and structure follow a depth-reference video alongside text prompting, optional first-frame image conditioning, with optional synchronized audio." "description": "Generates depth-controlled video with LTX-2: motion and structure follow a depth-reference video alongside text prompting, optional first-frame image conditioning, with optional synchronized audio."
}, },
{ {

@ -3350,7 +3350,7 @@
} }
], ],
"extra": {}, "extra": {},
"category": "Video generation and editing/First-Last-Frame to Video", "category": "Video generation and editing/Conditioned",
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3." "description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
} }
] ]

@ -3350,7 +3350,7 @@
} }
], ],
"extra": {}, "extra": {},
"category": "Video generation and editing/First-Last-Frame to Video", "category": "Video generation and editing/FLF2V",
"description": "Generates a video that interpolates between the first and last keyframes using LTX-2.3, including optional audio." "description": "Generates a video that interpolates between the first and last keyframes using LTX-2.3, including optional audio."
} }
] ]

File diff suppressed because it is too large

@ -310,7 +310,7 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Text generation/Image Captioning", "category": "Image Tools",
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM." "description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
} }
] ]

@ -1,19 +1,18 @@
{ {
"id": "6af0a6c1-0161-4528-8685-65776e838d44",
"revision": 0, "revision": 0,
"last_node_id": 75, "last_node_id": 76,
"last_link_id": 245, "last_link_id": 0,
"nodes": [ "nodes": [
{ {
"id": 75, "id": 76,
"type": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf", "type": "96338968-1242-4f02-b6a1-d496af4bcffe",
"pos": [ "pos": [
600, 670,
830 1280
], ],
"size": [ "size": [
400, 400,
110 201.3125
], ],
"flags": {}, "flags": {},
"order": 0, "order": 0,
@ -59,47 +58,44 @@
"links": [] "links": []
} }
], ],
"title": "Image Depth Estimation (Lotus Depth)",
"properties": { "properties": {
"proxyWidgets": [ "proxyWidgets": [
[ [
"-1", "28",
"sigma" "sigma"
], ],
[ [
"-1", "10",
"unet_name" "unet_name"
], ],
[ [
"-1", "14",
"vae_name" "vae_name"
] ]
], ],
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.14.1" "ver": "0.14.1"
}, },
"widgets_values": [ "widgets_values": []
999.0000000000002,
"lotus-depth-d-v1-1.safetensors",
"vae-ft-mse-840000-ema-pruned.safetensors"
]
} }
], ],
"links": [], "links": [],
"groups": [], "version": 0.4,
"definitions": { "definitions": {
"subgraphs": [ "subgraphs": [
{ {
"id": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf", "id": "96338968-1242-4f02-b6a1-d496af4bcffe",
"version": 1, "version": 1,
"state": { "state": {
"lastGroupId": 1, "lastGroupId": 1,
"lastNodeId": 75, "lastNodeId": 76,
"lastLinkId": 245, "lastLinkId": 245,
"lastRerouteId": 0 "lastRerouteId": 0
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "Image to Depth Map (Lotus)", "name": "Image Depth Estimation (Lotus Depth)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
@ -191,12 +187,12 @@
"id": 10, "id": 10,
"type": "UNETLoader", "type": "UNETLoader",
"pos": [ "pos": [
108.05555555555557, 110,
-253.05555555555557 -250
], ],
"size": [ "size": [
254.93706597222226, 260,
82 90
], ],
"flags": {}, "flags": {},
"order": 4, "order": 4,
@ -234,9 +230,9 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "UNETLoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "UNETLoader",
"models": [ "models": [
{ {
"name": "lotus-depth-d-v1-1.safetensors", "name": "lotus-depth-d-v1-1.safetensors",
@ -255,12 +251,12 @@
"id": 18, "id": 18,
"type": "DisableNoise", "type": "DisableNoise",
"pos": [ "pos": [
607.0641494069639, 610,
-268.33337840371513 -270
], ],
"size": [ "size": [
175, 180,
33.333333333333336 40
], ],
"flags": {}, "flags": {},
"order": 0, "order": 0,
@ -278,26 +274,25 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "DisableNoise",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "DisableNoise",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, }
"widgets_values": []
}, },
{ {
"id": 23, "id": 74,
"type": "VAEEncode", "type": "VAEEncode",
"pos": [ "pos": [
620, 620,
160 160
], ],
"size": [ "size": [
175, 180,
50 50
], ],
"flags": {}, "flags": {},
"order": 10, "order": 11,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -325,12 +320,11 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAEEncode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "VAEEncode",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, }
"widgets_values": []
}, },
{ {
"id": 21, "id": 21,
@ -341,7 +335,7 @@
], ],
"size": [ "size": [
210, 210,
58 60
], ],
"flags": {}, "flags": {},
"order": 1, "order": 1,
@ -369,9 +363,9 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "KSamplerSelect",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "KSamplerSelect",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, },
"widgets_values": [ "widgets_values": [
@ -386,7 +380,7 @@
-170 -170
], ],
"size": [ "size": [
175, 180,
50 50
], ],
"flags": {}, "flags": {},
@ -418,12 +412,11 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "BasicGuider",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "BasicGuider",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, }
"widgets_values": []
}, },
{ {
"id": 16, "id": 16,
@ -433,8 +426,8 @@
-130 -130
], ],
"size": [ "size": [
295.99609375, 300,
271.65798611111114 280
], ],
"flags": {}, "flags": {},
"order": 6, "order": 6,
@ -490,12 +483,11 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "SamplerCustomAdvanced",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "SamplerCustomAdvanced",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, }
"widgets_values": []
}, },
{ {
"id": 28, "id": 28,
@ -506,10 +498,10 @@
], ],
"size": [ "size": [
210, 210,
58 60
], ],
"flags": {}, "flags": {},
"order": 11, "order": 10,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -540,9 +532,9 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "SetFirstSigma",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "SetFirstSigma",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, },
"widgets_values": [ "widgets_values": [
@ -557,7 +549,7 @@
-120 -120
], ],
"size": [ "size": [
175, 180,
50 50
], ],
"flags": {}, "flags": {},
@ -589,12 +581,11 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAEDecode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "VAEDecode",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, }
"widgets_values": []
}, },
{ {
"id": 22, "id": 22,
@ -604,8 +595,8 @@
-220 -220
], ],
"size": [ "size": [
175, 180,
33.333333333333336 40
], ],
"flags": {}, "flags": {},
"order": 9, "order": 9,
@ -630,12 +621,11 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "ImageInvert",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "ImageInvert",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, }
"widgets_values": []
}, },
{ {
"id": 14, "id": 14,
@ -645,8 +635,8 @@
-90 -90
], ],
"size": [ "size": [
254.93706597222226, 260,
58 60
], ],
"flags": {}, "flags": {},
"order": 5, "order": 5,
@ -675,9 +665,9 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAELoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "VAELoader",
"models": [ "models": [
{ {
"name": "vae-ft-mse-840000-ema-pruned.safetensors", "name": "vae-ft-mse-840000-ema-pruned.safetensors",
@ -692,15 +682,15 @@
] ]
}, },
{ {
"id": 68, "id": 75,
"type": "LotusConditioning", "type": "LotusConditioning",
"pos": [ "pos": [
400, 400,
-150 -150
], ],
"size": [ "size": [
175, 180,
33.333333333333336 40
], ],
"flags": {}, "flags": {},
"order": 2, "order": 2,
@ -718,12 +708,11 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "LotusConditioning",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "LotusConditioning",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, }
"widgets_values": []
}, },
{ {
"id": 20, "id": 20,
@ -734,7 +723,7 @@
], ],
"size": [ "size": [
210, 210,
106 110
], ],
"flags": {}, "flags": {},
"order": 8, "order": 8,
@ -786,9 +775,9 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "BasicScheduler",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.3.34", "ver": "0.3.34",
"Node name for S&R": "BasicScheduler",
"widget_ue_connectable": {} "widget_ue_connectable": {}
}, },
"widgets_values": [ "widgets_values": [
@ -850,7 +839,7 @@
}, },
{ {
"id": 201, "id": 201,
"origin_id": 23, "origin_id": 74,
"origin_slot": 0, "origin_slot": 0,
"target_id": 16, "target_id": 16,
"target_slot": 4, "target_slot": 4,
@ -866,7 +855,7 @@
}, },
{ {
"id": 238, "id": 238,
"origin_id": 68, "origin_id": 75,
"origin_slot": 0, "origin_slot": 0,
"target_id": 19, "target_id": 19,
"target_slot": 1, "target_slot": 1,
@ -892,7 +881,7 @@
"id": 38, "id": 38,
"origin_id": 14, "origin_id": 14,
"origin_slot": 0, "origin_slot": 0,
"target_id": 23, "target_id": 74,
"target_slot": 1, "target_slot": 1,
"type": "VAE" "type": "VAE"
}, },
@ -908,7 +897,7 @@
"id": 37, "id": 37,
"origin_id": -10, "origin_id": -10,
"origin_slot": 0, "origin_slot": 0,
"target_id": 23, "target_id": 74,
"target_slot": 0, "target_slot": 0,
"type": "IMAGE" "type": "IMAGE"
}, },
@ -948,12 +937,11 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Image generation and editing/Depth to image", "category": "Conditioning & Preprocessors/Depth",
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model." "description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
} }
] ]
}, },
"config": {},
"extra": { "extra": {
"ds": { "ds": {
"scale": 1.3589709866044692, "scale": 1.3589709866044692,
@ -961,8 +949,6 @@
-138.53613935617864, -138.53613935617864,
-786.0629126022195 -786.0629126022195
] ]
}, }
"workflowRendererVersion": "LG" }
},
"version": 0.4
} }

File diff suppressed because it is too large


@ -0,0 +1,779 @@
{
"revision": 0,
"last_node_id": 33,
"last_link_id": 0,
"nodes": [
{
"id": 33,
"type": "6062babb-b649-4a71-be9e-20ebce567744",
"pos": [
-450,
4240
],
"size": [
420,
400
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": null
},
{
"name": "face_landmarker",
"type": "FACE_LANDMARKER",
"link": null
},
{
"name": "detector_variant",
"type": "COMBO",
"widget": {
"name": "detector_variant"
},
"link": null
},
{
"name": "num_faces",
"type": "INT",
"widget": {
"name": "num_faces"
},
"link": null
},
{
"label": "custom_face_oval",
"name": "regions.face_oval",
"type": "BOOLEAN",
"widget": {
"name": "regions.face_oval"
},
"link": null
},
{
"label": "custom_lips",
"name": "regions.lips",
"type": "BOOLEAN",
"widget": {
"name": "regions.lips"
},
"link": null
},
{
"label": "custom_left_eye",
"name": "regions.left_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.left_eye"
},
"link": null
},
{
"label": "custom_right_eye",
"name": "regions.right_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.right_eye"
},
"link": null
},
{
"label": "custom_irises",
"name": "regions.irises",
"type": "BOOLEAN",
"widget": {
"name": "regions.irises"
},
"link": null
},
{
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "face_landmarks",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"links": []
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": []
},
{
"label": "mask",
"name": "MASK_1",
"type": "MASK",
"links": []
}
],
"title": "Image Face Detection (Mediapipe)",
"properties": {
"proxyWidgets": [
[
"11",
"detector_variant"
],
[
"11",
"num_faces"
],
[
"20",
"regions.face_oval"
],
[
"20",
"regions.lips"
],
[
"20",
"regions.left_eye"
],
[
"20",
"regions.right_eye"
],
[
"20",
"regions.irises"
],
[
"2",
"model_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.22.0",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": []
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "6062babb-b649-4a71-be9e-20ebce567744",
"version": 1,
"state": {
"lastGroupId": 2,
"lastNodeId": 158,
"lastLinkId": 140,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image Face Detection (Mediapipe)",
"description": "Detects facial landmarks from an image using MediaPipe, outputting landmark data, face bounding boxes, and an optional face-region mask.",
"inputNode": {
"id": -10,
"bounding": [
-710,
4300,
148.880859375,
248
]
},
"outputNode": {
"id": -20,
"bounding": [
140,
4480,
137.677734375,
108
]
},
"inputs": [
{
"id": "705dc1ae-6dc9-4155-92df-52f816ad451e",
"name": "image",
"type": "IMAGE",
"linkIds": [
60
],
"localized_name": "image",
"pos": [
-585.119140625,
4324
]
},
{
"id": "d6277190-732c-4604-b7cd-d3a9588bf761",
"name": "face_landmarker",
"type": "FACE_LANDMARKER",
"linkIds": [
74
],
"pos": [
-585.119140625,
4344
]
},
{
"id": "ac473a08-6a86-42a7-b460-e70c6c5e1e2b",
"name": "detector_variant",
"type": "COMBO",
"linkIds": [
75
],
"pos": [
-585.119140625,
4364
]
},
{
"id": "1bec2252-ca2d-496e-8a33-33a61d21f897",
"name": "num_faces",
"type": "INT",
"linkIds": [
76
],
"pos": [
-585.119140625,
4384
]
},
{
"id": "17994fa2-0ea0-4c9b-a70a-19789c459c80",
"name": "regions.face_oval",
"type": "BOOLEAN",
"linkIds": [
77
],
"label": "custom_face_oval",
"pos": [
-585.119140625,
4404
]
},
{
"id": "1c6c5893-2aee-4c37-b702-15ef2e20d863",
"name": "regions.lips",
"type": "BOOLEAN",
"linkIds": [
78
],
"label": "custom_lips",
"pos": [
-585.119140625,
4424
]
},
{
"id": "f353fcea-4b6f-42a1-8fdd-32b3aa1e1f09",
"name": "regions.left_eye",
"type": "BOOLEAN",
"linkIds": [
79
],
"label": "custom_left_eye",
"pos": [
-585.119140625,
4444
]
},
{
"id": "1387e121-c1fb-4522-8f0d-43459e11dd86",
"name": "regions.right_eye",
"type": "BOOLEAN",
"linkIds": [
80
],
"label": "custom_right_eye",
"pos": [
-585.119140625,
4464
]
},
{
"id": "14acb0a0-d1f4-48f3-ba31-811b26236ef9",
"name": "regions.irises",
"type": "BOOLEAN",
"linkIds": [
81
],
"label": "custom_irises",
"pos": [
-585.119140625,
4484
]
},
{
"id": "25a82859-87de-42c8-8431-09948665546e",
"name": "model_name",
"type": "COMBO",
"linkIds": [
86
],
"pos": [
-585.119140625,
4504
]
}
],
"outputs": [
{
"id": "d2ba3f92-e8b1-49c3-9590-cfad56c54cf4",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"linkIds": [
44
],
"localized_name": "face_landmarks",
"pos": [
164,
4504
]
},
{
"id": "4f356bb0-d4c4-4f93-b4cf-0845a65c4e6d",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
25
],
"localized_name": "bboxes",
"pos": [
164,
4524
]
},
{
"id": "f6309e1d-6397-4363-b38f-778a122abc51",
"name": "MASK_1",
"type": "MASK",
"linkIds": [
83
],
"label": "mask",
"pos": [
164,
4544
]
}
],
"widgets": [],
"nodes": [
{
"id": 11,
"type": "MediaPipeFaceLandmarker",
"pos": [
-280,
4280
],
"size": [
350,
220
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "face_detection_model",
"name": "face_detection_model",
"type": "FACE_DETECTION_MODEL",
"link": 66
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 60
},
{
"localized_name": "detector_variant",
"name": "detector_variant",
"type": "COMBO",
"widget": {
"name": "detector_variant"
},
"link": 75
},
{
"localized_name": "num_faces",
"name": "num_faces",
"type": "INT",
"widget": {
"name": "num_faces"
},
"link": 76
},
{
"localized_name": "min_confidence",
"name": "min_confidence",
"type": "FLOAT",
"widget": {
"name": "min_confidence"
},
"link": null
},
{
"localized_name": "missing_frame_fallback",
"name": "missing_frame_fallback",
"type": "COMBO",
"widget": {
"name": "missing_frame_fallback"
},
"link": null
},
{
"name": "face_landmarker",
"type": "FACE_LANDMARKER",
"link": 74
}
],
"outputs": [
{
"localized_name": "face_landmarks",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"links": [
44,
46
]
},
{
"localized_name": "bboxes",
"name": "bboxes",
"type": "BOUNDING_BOX",
"links": [
25
]
}
],
"properties": {
"Node name for S&R": "MediaPipeFaceLandmarker",
"cnr_id": "comfy-core",
"ver": "0.22.0",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
"full",
0,
0.5,
"empty"
]
},
{
"id": 2,
"type": "LoadMediaPipeFaceLandmarker",
"pos": [
-290,
4060
],
"size": [
350,
140
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "model_name",
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": 86
}
],
"outputs": [
{
"localized_name": "FACE_DETECTION_MODEL",
"name": "FACE_DETECTION_MODEL",
"type": "FACE_DETECTION_MODEL",
"links": [
66
]
}
],
"properties": {
"Node name for S&R": "LoadMediaPipeFaceLandmarker",
"cnr_id": "comfy-core",
"ver": "0.22.0",
"models": [
{
"name": "mediapipe_face_fp32.safetensors",
"url": "https://huggingface.co/Comfy-Org/mediapipe/resolve/main/detection/mediapipe_face_fp32.safetensors",
"directory": "detection"
}
],
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
"mediapipe_face_fp32.safetensors"
]
},
{
"id": 20,
"type": "MediaPipeFaceMask",
"pos": [
-290,
4560
],
"size": [
360,
180
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "face_landmarks",
"name": "face_landmarks",
"type": "FACE_LANDMARKS",
"link": 46
},
{
"localized_name": "regions",
"name": "regions",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "regions"
},
"link": null
},
{
"localized_name": "regions.face_oval",
"name": "regions.face_oval",
"type": "BOOLEAN",
"widget": {
"name": "regions.face_oval"
},
"link": 77
},
{
"localized_name": "regions.lips",
"name": "regions.lips",
"type": "BOOLEAN",
"widget": {
"name": "regions.lips"
},
"link": 78
},
{
"localized_name": "regions.left_eye",
"name": "regions.left_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.left_eye"
},
"link": 79
},
{
"localized_name": "regions.right_eye",
"name": "regions.right_eye",
"type": "BOOLEAN",
"widget": {
"name": "regions.right_eye"
},
"link": 80
},
{
"localized_name": "regions.irises",
"name": "regions.irises",
"type": "BOOLEAN",
"widget": {
"name": "regions.irises"
},
"link": 81
}
],
"outputs": [
{
"localized_name": "MASK",
"name": "MASK",
"type": "MASK",
"links": [
83
]
}
],
"properties": {
"Node name for S&R": "MediaPipeFaceMask",
"cnr_id": "comfy-core",
"ver": "0.22.0",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
"custom",
true,
false,
false,
false,
false
]
}
],
"groups": [],
"links": [
{
"id": 66,
"origin_id": 2,
"origin_slot": 0,
"target_id": 11,
"target_slot": 0,
"type": "FACE_DETECTION_MODEL"
},
{
"id": 46,
"origin_id": 11,
"origin_slot": 0,
"target_id": 20,
"target_slot": 0,
"type": "FACE_LANDMARKS"
},
{
"id": 60,
"origin_id": -10,
"origin_slot": 0,
"target_id": 11,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 44,
"origin_id": 11,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "FACE_LANDMARKS"
},
{
"id": 25,
"origin_id": 11,
"origin_slot": 1,
"target_id": -20,
"target_slot": 1,
"type": "BOUNDING_BOX"
},
{
"id": 74,
"origin_id": -10,
"origin_slot": 1,
"target_id": 11,
"target_slot": 6,
"type": "FACE_LANDMARKER"
},
{
"id": 75,
"origin_id": -10,
"origin_slot": 2,
"target_id": 11,
"target_slot": 2,
"type": "COMBO"
},
{
"id": 76,
"origin_id": -10,
"origin_slot": 3,
"target_id": 11,
"target_slot": 3,
"type": "INT"
},
{
"id": 77,
"origin_id": -10,
"origin_slot": 4,
"target_id": 20,
"target_slot": 2,
"type": "BOOLEAN"
},
{
"id": 78,
"origin_id": -10,
"origin_slot": 5,
"target_id": 20,
"target_slot": 3,
"type": "BOOLEAN"
},
{
"id": 79,
"origin_id": -10,
"origin_slot": 6,
"target_id": 20,
"target_slot": 4,
"type": "BOOLEAN"
},
{
"id": 80,
"origin_id": -10,
"origin_slot": 7,
"target_id": 20,
"target_slot": 5,
"type": "BOOLEAN"
},
{
"id": 81,
"origin_id": -10,
"origin_slot": 8,
"target_id": 20,
"target_slot": 6,
"type": "BOOLEAN"
},
{
"id": 83,
"origin_id": 20,
"origin_slot": 0,
"target_id": -20,
"target_slot": 2,
"type": "MASK"
},
{
"id": 86,
"origin_id": -10,
"origin_slot": 9,
"target_id": 2,
"target_slot": 0,
"type": "COMBO"
}
],
"extra": {},
"category": "Conditioning & Preprocessors/Face Detection"
}
]
},
"extra": {}
}
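
The wrapper node in the blueprint above exposes widgets of its inner nodes through `properties.proxyWidgets`, a list of `[inner node id, widget name]` pairs. A minimal Python sketch of resolving those pairs against the subgraph definition is given here. It assumes only the layout visible in this file (string ids inside `proxyWidgets`, integer ids inside `definitions.subgraphs[].nodes`); `resolve_proxy_widgets` and the trimmed `example_workflow` are hypothetical names for illustration and are not part of ComfyUI.

```python
def resolve_proxy_widgets(workflow: dict) -> list[tuple[str, str]]:
    """Return (inner node type, widget name) for every proxied widget."""
    # Index subgraph definitions by their UUID, which wrapper nodes use as "type".
    subgraphs = {sg["id"]: sg for sg in workflow["definitions"]["subgraphs"]}
    resolved = []
    for node in workflow["nodes"]:
        subgraph = subgraphs.get(node["type"])
        if subgraph is None:
            continue  # an ordinary node, not a subgraph wrapper
        # proxyWidgets stores inner ids as strings, so index by str(id).
        inner_types = {str(n["id"]): n["type"] for n in subgraph["nodes"]}
        for inner_id, widget_name in node["properties"].get("proxyWidgets", []):
            resolved.append((inner_types.get(inner_id, "<missing>"), widget_name))
    return resolved


# Trimmed stand-in for the blueprint above (most keys omitted for brevity).
example_workflow = {
    "nodes": [
        {
            "id": 33,
            "type": "6062babb-b649-4a71-be9e-20ebce567744",
            "properties": {
                "proxyWidgets": [
                    ["11", "detector_variant"],
                    ["20", "regions.lips"],
                    ["2", "model_name"],
                ]
            },
        }
    ],
    "definitions": {
        "subgraphs": [
            {
                "id": "6062babb-b649-4a71-be9e-20ebce567744",
                "nodes": [
                    {"id": 11, "type": "MediaPipeFaceLandmarker"},
                    {"id": 2, "type": "LoadMediaPipeFaceLandmarker"},
                    {"id": 20, "type": "MediaPipeFaceMask"},
                ],
            }
        ]
    },
}

resolved_widgets = resolve_proxy_widgets(example_workflow)
```

A pair that resolves to `<missing>` would indicate a proxied widget whose inner node no longer exists in the subgraph.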


@ -703,7 +703,7 @@
} }
], ],
"extra": {}, "extra": {},
"category": "Image Tools/Image Segmentation", "category": "Conditioning & Preprocessors/Segmentation & Mask",
"description": "Segments images into masks using Meta SAM3 from text prompts, points, or boxes." "description": "Segments images into masks using Meta SAM3 from text prompts, points, or boxes."
} }
] ]


@ -1302,7 +1302,7 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Image generation and editing/Enhance", "category": "Image generation and editing/Upscale",
"description": "Upscales images to higher resolution using Z-Image-Turbo." "description": "Upscales images to higher resolution using Z-Image-Turbo."
} }
] ]

File diff suppressed because it is too large


@ -0,0 +1,888 @@
{
"revision": 0,
"last_node_id": 675,
"last_link_id": 0,
"nodes": [
{
"id": 675,
"type": "01b6a731-fb78-4070-9a38-c87146da9604",
"pos": [
-2480,
3400
],
"size": [
360,
433.3125
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "input",
"name": "input",
"type": "IMAGE,MASK",
"link": null
},
{
"label": "resize_target_longer_size",
"name": "resize_type.longer_size",
"type": "INT",
"widget": {
"name": "resize_type.longer_size"
},
"link": null
},
{
"name": "scale_method",
"type": "COMBO",
"widget": {
"name": "scale_method"
},
"link": null
},
{
"name": "draw_body",
"type": "BOOLEAN",
"widget": {
"name": "draw_body"
},
"link": null
},
{
"name": "draw_hands",
"type": "BOOLEAN",
"widget": {
"name": "draw_hands"
},
"link": null
},
{
"name": "draw_face",
"type": "BOOLEAN",
"widget": {
"name": "draw_face"
},
"link": null
},
{
"name": "draw_feet",
"type": "BOOLEAN",
"widget": {
"name": "draw_feet"
},
"link": null
},
{
"name": "stick_width",
"type": "INT",
"widget": {
"name": "stick_width"
},
"link": null
},
{
"name": "face_point_size",
"type": "INT",
"widget": {
"name": "face_point_size"
},
"link": null
},
{
"name": "score_threshold",
"type": "FLOAT",
"widget": {
"name": "score_threshold"
},
"link": null
},
{
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": null
},
{
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": []
},
{
"name": "keypoints",
"type": "POSE_KEYPOINT",
"links": null
}
],
"properties": {
"proxyWidgets": [
[
"674",
"resize_type.longer_size"
],
[
"674",
"scale_method"
],
[
"672",
"draw_body"
],
[
"672",
"draw_hands"
],
[
"672",
"draw_face"
],
[
"672",
"draw_feet"
],
[
"672",
"stick_width"
],
[
"672",
"face_point_size"
],
[
"672",
"score_threshold"
],
[
"673",
"ckpt_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.15.1",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [],
"title": "Image to Pose Map (SDPose-OOD)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "01b6a731-fb78-4070-9a38-c87146da9604",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 676,
"lastLinkId": 1715,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Image to Pose Map (SDPose-OOD)",
"inputNode": {
"id": -10,
"bounding": [
-3290,
3590,
190.8984375,
288
]
},
"outputNode": {
"id": -20,
"bounding": [
-1756.2451602089645,
3366,
128,
88
]
},
"inputs": [
{
"id": "e24699c3-1356-4634-9eb4-19bb58e5c0b0",
"name": "input",
"type": "IMAGE,MASK",
"linkIds": [
1700
],
"localized_name": "input",
"pos": [
-3123.1015625,
3614
]
},
{
"id": "088eefc1-cd8a-4573-993f-9e4da008a12d",
"name": "resize_type.longer_size",
"type": "INT",
"linkIds": [
1704
],
"label": "resize_target_longer_size",
"pos": [
-3123.1015625,
3634
]
},
{
"id": "b6449bd3-73d4-41c8-b81f-cf8d33f76a2e",
"name": "scale_method",
"type": "COMBO",
"linkIds": [
1705
],
"pos": [
-3123.1015625,
3654
]
},
{
"id": "4cff52ad-ed07-4c97-8803-fcbd89554fd0",
"name": "draw_body",
"type": "BOOLEAN",
"linkIds": [
1706
],
"pos": [
-3123.1015625,
3674
]
},
{
"id": "7af63dce-f7df-4d7e-8215-d7c7f60bf81c",
"name": "draw_hands",
"type": "BOOLEAN",
"linkIds": [
1707
],
"pos": [
-3123.1015625,
3694
]
},
{
"id": "af3a9bce-61f9-4aca-b530-9f65e028b35e",
"name": "draw_face",
"type": "BOOLEAN",
"linkIds": [
1708
],
"pos": [
-3123.1015625,
3714
]
},
{
"id": "4620f6a3-2c85-4b79-ad8f-35d0326b568f",
"name": "draw_feet",
"type": "BOOLEAN",
"linkIds": [
1709
],
"pos": [
-3123.1015625,
3734
]
},
{
"id": "fee5d0c9-8d4b-4934-81d8-ba2206dc56cb",
"name": "stick_width",
"type": "INT",
"linkIds": [
1710
],
"pos": [
-3123.1015625,
3754
]
},
{
"id": "aafdd060-ba81-4324-a9cc-b656e1ebc133",
"name": "face_point_size",
"type": "INT",
"linkIds": [
1711
],
"pos": [
-3123.1015625,
3774
]
},
{
"id": "514c5503-f9e6-4d23-b1ae-1d3291acb2a3",
"name": "score_threshold",
"type": "FLOAT",
"linkIds": [
1712
],
"pos": [
-3123.1015625,
3794
]
},
{
"id": "ae46de61-2cc6-483e-8ee9-87e4144a2ffa",
"name": "ckpt_name",
"type": "COMBO",
"linkIds": [
1713
],
"pos": [
-3123.1015625,
3814
]
},
{
"id": "41bec0c6-dffa-4c78-9289-ee678715ae54",
"name": "bboxes",
"type": "BOUNDING_BOX",
"linkIds": [
1714
],
"pos": [
-3123.1015625,
3834
]
}
],
"outputs": [
{
"id": "f05ed8cc-9403-4f14-8085-4364b06f8a48",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
1701
],
"localized_name": "IMAGE",
"pos": [
-1732.2451602089645,
3390
]
},
{
"id": "29a6584e-4685-4986-8ffd-e6d8539953fd",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"linkIds": [
1715
],
"pos": [
-1732.2451602089645,
3410
]
}
],
"widgets": [],
"nodes": [
{
"id": 671,
"type": "SDPoseKeypointExtractor",
"pos": [
-2470,
3250
],
"size": [
270,
180
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "model",
"name": "model",
"type": "MODEL",
"link": 1696
},
{
"localized_name": "vae",
"name": "vae",
"type": "VAE",
"link": 1697
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 1698
},
{
"localized_name": "bboxes",
"name": "bboxes",
"shape": 7,
"type": "BOUNDING_BOX",
"link": 1714
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"widget": {
"name": "batch_size"
},
"link": null
}
],
"outputs": [
{
"localized_name": "keypoints",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"links": [
1699,
1715
]
}
],
"properties": {
"Node name for S&R": "SDPoseKeypointExtractor",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
16
]
},
{
"id": 674,
"type": "ResizeImageMaskNode",
"pos": [
-2960,
3490
],
"size": [
270,
110
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "input",
"name": "input",
"type": "IMAGE,MASK",
"link": 1700
},
{
"localized_name": "resize_type",
"name": "resize_type",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "resize_type"
},
"link": null
},
{
"localized_name": "resize_type.longer_size",
"name": "resize_type.longer_size",
"type": "INT",
"widget": {
"name": "resize_type.longer_size"
},
"link": 1704
},
{
"localized_name": "scale_method",
"name": "scale_method",
"type": "COMBO",
"widget": {
"name": "scale_method"
},
"link": 1705
}
],
"outputs": [
{
"localized_name": "resized",
"name": "resized",
"type": "*",
"links": [
1698
]
}
],
"properties": {
"Node name for S&R": "ResizeImageMaskNode",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"scale longer dimension",
1024,
"area"
]
},
{
"id": 672,
"type": "SDPoseDrawKeypoints",
"pos": [
-2120,
3260
],
"size": [
270,
280
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "keypoints",
"name": "keypoints",
"type": "POSE_KEYPOINT",
"link": 1699
},
{
"localized_name": "draw_body",
"name": "draw_body",
"type": "BOOLEAN",
"widget": {
"name": "draw_body"
},
"link": 1706
},
{
"localized_name": "draw_hands",
"name": "draw_hands",
"type": "BOOLEAN",
"widget": {
"name": "draw_hands"
},
"link": 1707
},
{
"localized_name": "draw_face",
"name": "draw_face",
"type": "BOOLEAN",
"widget": {
"name": "draw_face"
},
"link": 1708
},
{
"localized_name": "draw_feet",
"name": "draw_feet",
"type": "BOOLEAN",
"widget": {
"name": "draw_feet"
},
"link": 1709
},
{
"localized_name": "stick_width",
"name": "stick_width",
"type": "INT",
"widget": {
"name": "stick_width"
},
"link": 1710
},
{
"localized_name": "face_point_size",
"name": "face_point_size",
"type": "INT",
"widget": {
"name": "face_point_size"
},
"link": 1711
},
{
"localized_name": "score_threshold",
"name": "score_threshold",
"type": "FLOAT",
"widget": {
"name": "score_threshold"
},
"link": 1712
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": [
1701
]
}
],
"properties": {
"Node name for S&R": "SDPoseDrawKeypoints",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
true,
true,
true,
true,
4,
2,
0.5
]
},
{
"id": 673,
"type": "CheckpointLoaderSimple",
"pos": [
-2960,
3250
],
"size": [
390,
190
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "ckpt_name",
"name": "ckpt_name",
"type": "COMBO",
"widget": {
"name": "ckpt_name"
},
"link": 1713
}
],
"outputs": [
{
"localized_name": "MODEL",
"name": "MODEL",
"type": "MODEL",
"links": [
1696
]
},
{
"localized_name": "CLIP",
"name": "CLIP",
"type": "CLIP",
"links": []
},
{
"localized_name": "VAE",
"name": "VAE",
"type": "VAE",
"links": [
1697
]
}
],
"properties": {
"Node name for S&R": "CheckpointLoaderSimple",
"cnr_id": "comfy-core",
"ver": "0.15.0",
"models": [
{
"name": "sdpose_wholebody_fp16.safetensors",
"url": "https://huggingface.co/Comfy-Org/SDPose/resolve/main/checkpoints/sdpose_wholebody_fp16.safetensors",
"directory": "checkpoints"
}
],
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"sdpose_wholebody_fp16.safetensors"
]
}
],
"groups": [],
"links": [
{
"id": 1696,
"origin_id": 673,
"origin_slot": 0,
"target_id": 671,
"target_slot": 0,
"type": "MODEL"
},
{
"id": 1697,
"origin_id": 673,
"origin_slot": 2,
"target_id": 671,
"target_slot": 1,
"type": "VAE"
},
{
"id": 1698,
"origin_id": 674,
"origin_slot": 0,
"target_id": 671,
"target_slot": 2,
"type": "IMAGE"
},
{
"id": 1699,
"origin_id": 671,
"origin_slot": 0,
"target_id": 672,
"target_slot": 0,
"type": "POSE_KEYPOINT"
},
{
"id": 1700,
"origin_id": -10,
"origin_slot": 0,
"target_id": 674,
"target_slot": 0,
"type": "IMAGE,MASK"
},
{
"id": 1701,
"origin_id": 672,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 1704,
"origin_id": -10,
"origin_slot": 1,
"target_id": 674,
"target_slot": 2,
"type": "INT"
},
{
"id": 1705,
"origin_id": -10,
"origin_slot": 2,
"target_id": 674,
"target_slot": 3,
"type": "COMBO"
},
{
"id": 1706,
"origin_id": -10,
"origin_slot": 3,
"target_id": 672,
"target_slot": 1,
"type": "BOOLEAN"
},
{
"id": 1707,
"origin_id": -10,
"origin_slot": 4,
"target_id": 672,
"target_slot": 2,
"type": "BOOLEAN"
},
{
"id": 1708,
"origin_id": -10,
"origin_slot": 5,
"target_id": 672,
"target_slot": 3,
"type": "BOOLEAN"
},
{
"id": 1709,
"origin_id": -10,
"origin_slot": 6,
"target_id": 672,
"target_slot": 4,
"type": "BOOLEAN"
},
{
"id": 1710,
"origin_id": -10,
"origin_slot": 7,
"target_id": 672,
"target_slot": 5,
"type": "INT"
},
{
"id": 1711,
"origin_id": -10,
"origin_slot": 8,
"target_id": 672,
"target_slot": 6,
"type": "INT"
},
{
"id": 1712,
"origin_id": -10,
"origin_slot": 9,
"target_id": 672,
"target_slot": 7,
"type": "FLOAT"
},
{
"id": 1713,
"origin_id": -10,
"origin_slot": 10,
"target_id": 673,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 1714,
"origin_id": -10,
"origin_slot": 11,
"target_id": 671,
"target_slot": 3,
"type": "BOUNDING_BOX"
},
{
"id": 1715,
"origin_id": 671,
"origin_slot": 0,
"target_id": -20,
"target_slot": 1,
"type": "POSE_KEYPOINT"
}
],
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Conditioning & Preprocessors/Pose",
"description": "Extracts human pose keypoints and stick-figure visuals from an image using SDPose-OOD, with optional bounding-box input per subject."
}
]
},
"extra": {
"ue_links": []
}
}
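
In the subgraph above, every entry in `links` is mirrored on both endpoints: inner nodes record it in `inputs[].link` and `outputs[].links`, while the subgraph boundary (node ids `-10` and `-20`) records it in `inputs[].linkIds` and `outputs[].linkIds`. The following Python sketch checks that mirroring. It assumes only the field names visible in this file; `check_links` and the heavily trimmed `example_subgraph` are hypothetical and not part of ComfyUI.

```python
import copy

INPUT_NODE_ID = -10   # id of "inputNode" in the blueprint above
OUTPUT_NODE_ID = -20  # id of "outputNode" in the blueprint above


def check_links(subgraph: dict) -> list[str]:
    """Return a description of every link that is not mirrored on both ends."""
    nodes = {n["id"]: n for n in subgraph["nodes"]}
    problems = []
    for link in subgraph["links"]:
        link_id = link["id"]
        # Origin side: a subgraph input slot or an inner node output slot.
        if link["origin_id"] == INPUT_NODE_ID:
            origin_ids = subgraph["inputs"][link["origin_slot"]]["linkIds"]
        else:
            slot = nodes[link["origin_id"]]["outputs"][link["origin_slot"]]
            origin_ids = slot["links"] or []  # "links" may be null
        if link_id not in origin_ids:
            problems.append(f"link {link_id}: missing on origin slot")
        # Target side: a subgraph output slot or an inner node input slot.
        if link["target_id"] == OUTPUT_NODE_ID:
            target_ids = subgraph["outputs"][link["target_slot"]]["linkIds"]
        else:
            slot = nodes[link["target_id"]]["inputs"][link["target_slot"]]
            target_ids = [slot["link"]]
        if link_id not in target_ids:
            problems.append(f"link {link_id}: missing on target slot")
    return problems


# Simplified single-node stand-in; the real blueprint routes through more nodes.
example_subgraph = {
    "inputs": [{"name": "input", "linkIds": [1700]}],
    "outputs": [{"name": "IMAGE", "linkIds": [1701]}],
    "nodes": [
        {
            "id": 674,
            "inputs": [{"name": "input", "link": 1700}],
            "outputs": [{"name": "resized", "links": [1701]}],
        }
    ],
    "links": [
        {"id": 1700, "origin_id": -10, "origin_slot": 0, "target_id": 674, "target_slot": 0},
        {"id": 1701, "origin_id": 674, "origin_slot": 0, "target_id": -20, "target_slot": 0},
    ],
}

# A copy with one endpoint cleared, to show what a reported problem looks like.
broken_subgraph = copy.deepcopy(example_subgraph)
broken_subgraph["nodes"][0]["inputs"][0]["link"] = None
```

A check like this can be useful when node ids are renumbered by hand, as in the `68` to `75` and `23` to `74` changes earlier in this diff.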

1219
blueprints/Merge Videos.json Normal file

File diff suppressed because it is too large


@ -1298,7 +1298,7 @@
"VHS_MetadataImage": true, "VHS_MetadataImage": true,
"VHS_KeepIntermediate": true "VHS_KeepIntermediate": true
}, },
"category": "Image generation and editing/Pose to image", "category": "Image generation and editing/Conditioned",
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning." "description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
} }
] ]


@ -3870,7 +3870,7 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Video generation and editing/Pose to video", "category": "Video generation and editing/Conditioned",
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio." "description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
} }
] ]


@ -270,7 +270,7 @@
"extra": { "extra": {
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"category": "Text generation/Prompt enhance", "category": "Text Tools",
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality." "description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
} }
] ]


@ -389,7 +389,7 @@
} }
], ],
"extra": {}, "extra": {},
"category": "Image generation and editing/Background Removal" "category": "Image Tools/Background Removal"
} }
] ]
}, },


@ -0,0 +1,485 @@
{
"revision": 0,
"last_node_id": 10,
"last_link_id": 0,
"nodes": [
{
"id": 10,
"type": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
"pos": [
-250,
8590
],
"size": [
280,
360
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "text_per_line",
"name": "text_per_line",
"type": "STRING",
"widget": {
"name": "text_per_line"
},
"link": null
},
{
"localized_name": "index",
"name": "index",
"type": "INT",
"widget": {
"name": "index"
},
"link": null
}
],
"outputs": [
{
"localized_name": "selected_line",
"name": "selected_line",
"type": "STRING",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"2",
"string"
],
[
"3",
"value"
]
],
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [],
"title": "Select Per-Line Text by Index"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "3fb7557a-470d-4983-9d8c-6d5caa9788f0",
"version": 1,
"state": {
"lastGroupId": 0,
"lastNodeId": 10,
"lastLinkId": 14,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Select Per-Line Text by Index",
"inputNode": {
"id": -10,
"bounding": [
-990,
8595,
128,
88
]
},
"outputNode": {
"id": -20,
"bounding": [
710,
8585,
128,
68
]
},
"inputs": [
{
"id": "75417d82-a934-4ac9-b667-d8dcd5a3bfb3",
"name": "text_per_line",
"type": "STRING",
"linkIds": [
13
],
"localized_name": "text_per_line",
"pos": [
-886,
8619
]
},
{
"id": "46e69a73-1804-4ca6-9175-31445bf0be96",
"name": "index",
"type": "INT",
"linkIds": [
14
],
"localized_name": "index",
"pos": [
-886,
8639
]
}
],
"outputs": [
{
"id": "e34e8ad1-84d2-4bd2-a460-eb7de6067c10",
"name": "selected_line",
"type": "STRING",
"linkIds": [
10
],
"localized_name": "selected_line",
"pos": [
734,
8609
]
}
],
"widgets": [],
"nodes": [
{
"id": 1,
"type": "PreviewAny",
"pos": [
-500,
8400
],
"size": [
230,
180
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "source",
"name": "source",
"type": "*",
"link": 1
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
6
]
}
],
"properties": {
"Node name for S&R": "PreviewAny",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
null,
null,
null
]
},
{
"id": 2,
"type": "RegexExtract",
"pos": [
-240,
8740
],
"size": [
470,
460
],
"flags": {},
"order": 1,
"mode": 0,
"showAdvanced": false,
"inputs": [
{
"localized_name": "string",
"name": "string",
"type": "STRING",
"widget": {
"name": "string"
},
"link": 13
},
{
"localized_name": "regex_pattern",
"name": "regex_pattern",
"type": "STRING",
"widget": {
"name": "regex_pattern"
},
"link": 9
},
{
"localized_name": "mode",
"name": "mode",
"type": "COMBO",
"widget": {
"name": "mode"
},
"link": null
},
{
"localized_name": "case_insensitive",
"name": "case_insensitive",
"type": "BOOLEAN",
"widget": {
"name": "case_insensitive"
},
"link": null
},
{
"localized_name": "multiline",
"name": "multiline",
"type": "BOOLEAN",
"widget": {
"name": "multiline"
},
"link": null
},
{
"localized_name": "dotall",
"name": "dotall",
"type": "BOOLEAN",
"widget": {
"name": "dotall"
},
"link": null
},
{
"localized_name": "group_index",
"name": "group_index",
"type": "INT",
"widget": {
"name": "group_index"
},
"link": null
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
10
]
}
],
"properties": {
"Node name for S&R": "RegexExtract",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"",
"",
"First Group",
false,
false,
false,
1
]
},
{
"id": 3,
"type": "PrimitiveInt",
"pos": [
-810,
8400
],
"size": [
270,
110
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 14
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
1
]
}
],
"title": "Int (line index)",
"properties": {
"Node name for S&R": "Int (line index)",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
0,
"fixed"
]
},
{
"id": 8,
"type": "StringReplace",
"pos": [
-240,
8400
],
"size": [
400,
280
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "string",
"name": "string",
"type": "STRING",
"widget": {
"name": "string"
},
"link": null
},
{
"localized_name": "find",
"name": "find",
"type": "STRING",
"widget": {
"name": "find"
},
"link": null
},
{
"localized_name": "replace",
"name": "replace",
"type": "STRING",
"widget": {
"name": "replace"
},
"link": 6
}
],
"outputs": [
{
"localized_name": "STRING",
"name": "STRING",
"type": "STRING",
"links": [
9
]
}
],
"properties": {
"Node name for S&R": "StringReplace",
"cnr_id": "comfy-core",
"ver": "0.19.0",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"^(?:[^\\n]*\\n){index}([^\\n]*)(?:\\n|$)",
"index",
""
]
}
],
"groups": [],
"links": [
{
"id": 1,
"origin_id": 3,
"origin_slot": 0,
"target_id": 1,
"target_slot": 0,
"type": "INT"
},
{
"id": 9,
"origin_id": 8,
"origin_slot": 0,
"target_id": 2,
"target_slot": 1,
"type": "STRING"
},
{
"id": 6,
"origin_id": 1,
"origin_slot": 0,
"target_id": 8,
"target_slot": 2,
"type": "STRING"
},
{
"id": 10,
"origin_id": 2,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "STRING"
},
{
"id": 13,
"origin_id": -10,
"origin_slot": 0,
"target_id": 2,
"target_slot": 0,
"type": "STRING"
},
{
"id": 14,
"origin_id": -10,
"origin_slot": 1,
"target_id": 3,
"target_slot": 0,
"type": "INT"
}
],
"extra": {},
"category": "Text Tools",
"description": "Selects one line from multiline text by zero-based index for batch or list-driven prompt workflows."
}
]
},
"extra": {
"ue_links": [],
"links_added_by_ue": []
}
}
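
The blueprint above selects a line by building a regex at run time: the `StringReplace` node substitutes the integer index for the literal text `index` in the template, and `RegexExtract` returns the first capture group. A minimal Python sketch of the same idea follows. It assumes `re.search` semantics with no flags, matching the `false` values for `case_insensitive`, `multiline`, and `dotall`; `select_line` is a hypothetical helper, and returning `None` when no line exists is a choice made for this sketch rather than the node's documented behavior.

```python
import re

# Template copied from the StringReplace node's widgets_values (JSON-unescaped).
PATTERN_TEMPLATE = r"^(?:[^\n]*\n){index}([^\n]*)(?:\n|$)"


def select_line(text: str, index: int):
    """Return the zero-based line `index` of `text`, or None if it is absent."""
    # Turn "{index}" into a repetition count such as "{2}": skip that many lines.
    pattern = PATTERN_TEMPLATE.replace("index", str(index))
    match = re.search(pattern, text)
    # Group 1 captures the line that follows the skipped ones.
    return match.group(1) if match else None


prompts = "red\ngreen\nblue"
first = select_line(prompts, 0)
second = select_line(prompts, 1)
last = select_line(prompts, 2)
out_of_range = select_line(prompts, 3)
```

Because `^` is used without the multiline flag, the match is anchored at the start of the whole text, so the repetition count alone determines which line is captured.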


@ -0,0 +1,714 @@
{
"revision": 0,
"last_node_id": 251,
"last_link_id": 0,
"nodes": [
{
"id": 251,
"type": "609e1fd1-b731-4b78-89ac-d19b1156b025",
"pos": [
-1490,
130
],
"size": [
230,
164
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "source_image",
"name": "source_image",
"type": "IMAGE",
"link": null
},
{
"localized_name": "columns",
"name": "columns",
"type": "INT",
"widget": {
"name": "columns"
},
"link": null
},
{
"localized_name": "rows",
"name": "rows",
"type": "INT",
"widget": {
"name": "rows"
},
"link": null
}
],
"outputs": [
{
"localized_name": "tiles",
"name": "tiles",
"type": "IMAGE",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"228",
"value"
],
[
"252",
"value"
]
],
"cnr_id": "comfy-core",
"ver": "0.20.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [],
"title": "Split Image Grid to Tiles"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "609e1fd1-b731-4b78-89ac-d19b1156b025",
"version": 1,
"state": {
"lastGroupId": 9,
"lastNodeId": 252,
"lastLinkId": 429,
"lastRerouteId": 0
},
"revision": 0,
"config": {},
"name": "Split Image Grid to Tiles",
"inputNode": {
"id": -10,
"bounding": [
-1690,
260,
128,
108
]
},
"outputNode": {
"id": -20,
"bounding": [
-510,
590,
128,
68
]
},
"inputs": [
{
"id": "866ac798-cfbc-450a-b755-e704f86404d9",
"name": "source_image",
"type": "IMAGE",
"linkIds": [
386,
389
],
"localized_name": "source_image",
"pos": [
-1586,
284
]
},
{
"id": "bc37b1f8-8ab2-4f19-bd00-75d4fbc4feb3",
"name": "columns",
"type": "INT",
"linkIds": [
427
],
"localized_name": "columns",
"pos": [
-1586,
304
]
},
{
"id": "d45915da-e848-43dd-9ccc-e3161e9c99d9",
"name": "rows",
"type": "INT",
"linkIds": [
428
],
"localized_name": "rows",
"pos": [
-1586,
324
]
}
],
"outputs": [
{
"id": "18bc780f-064b-4038-87c6-67dba71deb08",
"name": "tiles",
"type": "IMAGE",
"linkIds": [
394
],
"localized_name": "tiles",
"shape": 6,
"pos": [
-486,
614
]
}
],
"widgets": [],
"nodes": [
{
"id": 225,
"type": "SplitImageToTileList",
"pos": [
-1010,
620
],
"size": [
290,
170
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 386
},
{
"localized_name": "tile_width",
"name": "tile_width",
"type": "INT",
"widget": {
"name": "tile_width"
},
"link": 403
},
{
"localized_name": "tile_height",
"name": "tile_height",
"type": "INT",
"widget": {
"name": "tile_height"
},
"link": 404
},
{
"localized_name": "overlap",
"name": "overlap",
"type": "INT",
"widget": {
"name": "overlap"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"shape": 6,
"type": "IMAGE",
"links": [
394
]
}
],
"properties": {
"Node name for S&R": "SplitImageToTileList",
"cnr_id": "comfy-core",
"ver": "0.20.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65
},
"widgets_values": [
1024,
1024,
0
]
},
{
"id": 231,
"type": "ComfyMathExpression",
"pos": [
-1080,
330
],
"size": [
370,
190
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT,BOOLEAN",
"link": 390
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": 429
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
404
]
},
{
"localized_name": "BOOL",
"name": "BOOL",
"type": "BOOLEAN",
"links": null
}
],
"title": "Math Expression Height",
"properties": {
"Node name for S&R": "ComfyMathExpression",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"max(1, (int(a) + int(b) - 1) // int(b))"
]
},
{
"id": 229,
"type": "ComfyMathExpression",
"pos": [
-1090,
-30
],
"size": [
370,
190
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"label": "a",
"localized_name": "values.a",
"name": "values.a",
"type": "FLOAT,INT,BOOLEAN",
"link": 387
},
{
"label": "b",
"localized_name": "values.b",
"name": "values.b",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": 388
},
{
"label": "c",
"localized_name": "values.c",
"name": "values.c",
"shape": 7,
"type": "FLOAT,INT,BOOLEAN",
"link": null
},
{
"localized_name": "expression",
"name": "expression",
"type": "STRING",
"widget": {
"name": "expression"
},
"link": null
}
],
"outputs": [
{
"localized_name": "FLOAT",
"name": "FLOAT",
"type": "FLOAT",
"links": null
},
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
403
]
},
{
"localized_name": "BOOL",
"name": "BOOL",
"type": "BOOLEAN",
"links": null
}
],
"title": "Math Expression Width",
"properties": {
"Node name for S&R": "ComfyMathExpression",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
"max(1, (int(a) + int(b) - 1) // int(b))"
]
},
{
"id": 228,
"type": "PrimitiveInt",
"pos": [
-1380,
90
],
"size": [
230,
110
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 427
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
388
]
}
],
"title": "Int (grid columns)",
"properties": {
"Node name for S&R": "Int (grid columns)",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
2,
"fixed"
]
},
{
"id": 230,
"type": "GetImageSize",
"pos": [
-1380,
290
],
"size": [
230,
100
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 389
}
],
"outputs": [
{
"localized_name": "width",
"name": "width",
"type": "INT",
"links": [
387
]
},
{
"localized_name": "height",
"name": "height",
"type": "INT",
"links": [
390
]
},
{
"localized_name": "batch_size",
"name": "batch_size",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetImageSize",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
}
},
{
"id": 252,
"type": "PrimitiveInt",
"pos": [
-1380,
470
],
"size": [
230,
110
],
"flags": {},
"order": 5,
"mode": 0,
"inputs": [
{
"localized_name": "value",
"name": "value",
"type": "INT",
"widget": {
"name": "value"
},
"link": 428
}
],
"outputs": [
{
"localized_name": "INT",
"name": "INT",
"type": "INT",
"links": [
429
]
}
],
"title": "Int (grid rows)",
"properties": {
"Node name for S&R": "Int (grid rows)",
"cnr_id": "comfy-core",
"ver": "0.18.1",
"enableTabs": false,
"tabWidth": 65,
"tabXOffset": 10,
"hasSecondTab": false,
"secondTabText": "Send Back",
"secondTabOffset": 80,
"secondTabWidth": 65,
"ue_properties": {
"widget_ue_connectable": {},
"version": "7.7",
"input_ue_unconnectable": {}
}
},
"widgets_values": [
3,
"fixed"
]
}
],
"groups": [],
"links": [
{
"id": 403,
"origin_id": 229,
"origin_slot": 1,
"target_id": 225,
"target_slot": 1,
"type": "INT"
},
{
"id": 404,
"origin_id": 231,
"origin_slot": 1,
"target_id": 225,
"target_slot": 2,
"type": "INT"
},
{
"id": 390,
"origin_id": 230,
"origin_slot": 1,
"target_id": 231,
"target_slot": 0,
"type": "INT"
},
{
"id": 387,
"origin_id": 230,
"origin_slot": 0,
"target_id": 229,
"target_slot": 0,
"type": "INT"
},
{
"id": 388,
"origin_id": 228,
"origin_slot": 0,
"target_id": 229,
"target_slot": 1,
"type": "INT"
},
{
"id": 386,
"origin_id": -10,
"origin_slot": 0,
"target_id": 225,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 389,
"origin_id": -10,
"origin_slot": 0,
"target_id": 230,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 394,
"origin_id": 225,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 427,
"origin_id": -10,
"origin_slot": 1,
"target_id": 228,
"target_slot": 0,
"type": "INT"
},
{
"id": 428,
"origin_id": -10,
"origin_slot": 2,
"target_id": 252,
"target_slot": 0,
"type": "INT"
},
{
"id": 429,
"origin_id": 252,
"origin_slot": 0,
"target_id": 231,
"target_slot": 1,
"type": "INT"
}
],
"extra": {},
"category": "Image Tools/Crop",
"description": "Splits an image into a configurable columns×rows grid of equal tiles for tiled generation or processing."
}
]
},
"extra": {}
}
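
Both `ComfyMathExpression` nodes in this subgraph use the same expression, `max(1, (int(a) + int(b) - 1) // int(b))`, where `a` is the image width or height and `b` is the column or row count. That is integer ceiling division with a floor of 1. The sketch below restates it in plain Python; the function name `tile_size` is an assumption for illustration only.

```python
def tile_size(length: int, parts: int) -> int:
    """Ceiling of length / parts using integer math, never smaller than 1."""
    return max(1, (int(length) + int(parts) - 1) // int(parts))


if __name__ == "__main__":
    # A 1000x900 image split into 3 columns and 2 rows.
    print(tile_size(1000, 3), tile_size(900, 2))
```

Rounding up rather than down means the tiles cover the whole image even when the dimension is not an exact multiple of the grid count.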

File diff suppressed because it is too large Load Diff

View File

@ -307,7 +307,7 @@
 "extra": {
 "workflowRendererVersion": "LG"
 },
-"category": "Text generation/Video Captioning",
+"category": "Video Tools",
 "description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
 }
 ]

File diff suppressed because it is too large Load Diff


View File

@ -818,7 +818,7 @@
 }
 ],
 "extra": {},
-"category": "Video Tools",
+"category": "Conditioning & Preprocessors/Segmentation & Mask",
 "description": "Segments video into temporally consistent masks using Meta SAM3 from text or interactive prompts."
 }
 ]

View File

@ -412,7 +412,7 @@
 "extra": {
 "workflowRendererVersion": "LG"
 },
-"category": "Video generation and editing/Enhance video",
+"category": "Video generation and editing/Upscale",
 "description": "Upscales video to 4× resolution using a GAN-based upscaling model."
 }
 ]

File diff suppressed because it is too large Load Diff

View File

@ -55,12 +55,7 @@ class BackgroundRemovalModel():
 out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
 mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
-if mask.ndim == 3:
-mask = mask.unsqueeze(0)
-if mask.shape[1] != 1:
-mask = mask.movedim(-1, 1)
-return mask
+return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W)
 def load_background_removal_model(sd):

View File

@ -49,7 +49,7 @@ parser.add_argument("--temp-directory", type=str, default=None, help="Set the Co
parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.") parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.") parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.") parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use. All other devices will not be visible.") parser.add_argument("--cuda-device", type=str, default=None, metavar="DEVICE_ID", help="Set the ids of cuda devices this instance will use, as a comma-separated list (e.g. '0' or '0,1'). All other devices will not be visible.")
parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.") parser.add_argument("--default-device", type=int, default=None, metavar="DEFAULT_DEVICE_ID", help="Set the id of the default device, all other devices will stay visible.")
cm_group = parser.add_mutually_exclusive_group() cm_group = parser.add_mutually_exclusive_group()
cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).") cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group() cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 10%% of system RAM (min 2GB, max 10GB), inactive 100%% of system RAM (max 96GB).")
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -245,6 +243,9 @@ if comfy.options.args_parsing:
else: else:
args = parser.parse_args([]) args = parser.parse_args([])
if args.cache_ram is not None and len(args.cache_ram) > 2:
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
if args.windows_standalone_build: if args.windows_standalone_build:
args.auto_launch = True args.auto_launch = True
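
The new `--cache-ram` definition in this diff takes zero, one, or two float values (`nargs='*'`) and rejects a third value after parsing. The standalone sketch below reproduces only that parsing behavior with `argparse`. The helper names `build_parser` and `parse_cache_ram` are hypothetical, and the sketch leaves out the mutually exclusive group and the default thresholds described in the help text.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # nargs='*' yields a list: [] when the flag is absent or given bare,
    # [active] for one value, [active, inactive] for two.
    parser.add_argument("--cache-ram", nargs="*", type=float, default=[], metavar="GB")
    return parser


def parse_cache_ram(argv: list) -> list:
    parser = build_parser()
    args = parser.parse_args(argv)
    # argparse cannot cap nargs='*', so the upper bound is checked afterwards.
    if args.cache_ram is not None and len(args.cache_ram) > 2:
        parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
    return args.cache_ram


if __name__ == "__main__":
    print(parse_cache_ram(["--cache-ram", "4", "16"]))
```

With `default=[]`, omitting the flag and passing it without values both produce an empty list, so in this sketch the two cases are not distinguishable from the parsed value alone.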

View File

@ -1,6 +1,5 @@
"""Comfy-specific type hinting""" """Comfy-specific type hinting"""
from __future__ import annotations
from typing import Literal, TypedDict, Optional from typing import Literal, TypedDict, Optional
from typing_extensions import NotRequired from typing_extensions import NotRequired
from abc import ABC, abstractmethod from abc import ABC, abstractmethod

View File

@ -15,13 +15,14 @@
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import torch import torch
from enum import Enum from enum import Enum
import math import math
import os import os
import logging import logging
import copy
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import comfy.model_detection import comfy.model_detection
@ -38,7 +39,7 @@ import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet import comfy.ldm.flux.controlnet
import comfy.ldm.qwen_image.controlnet import comfy.ldm.qwen_image.controlnet
import comfy.cldm.dit_embedder import comfy.cldm.dit_embedder
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Union
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.hooks import HookGroup from comfy.hooks import HookGroup
@ -64,6 +65,18 @@ class StrengthType(Enum):
CONSTANT = 1 CONSTANT = 1
LINEAR_UP = 2 LINEAR_UP = 2
class ControlIsolation:
'''Temporarily set a ControlBase object's previous_controlnet to None to prevent cascading calls.'''
def __init__(self, control: ControlBase):
self.control = control
self.orig_previous_controlnet = control.previous_controlnet
def __enter__(self):
self.control.previous_controlnet = None
def __exit__(self, *args):
self.control.previous_controlnet = self.orig_previous_controlnet
class ControlBase: class ControlBase:
def __init__(self): def __init__(self):
self.cond_hint_original = None self.cond_hint_original = None
@ -77,7 +90,7 @@ class ControlBase:
self.compression_ratio = 8 self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact' self.upscale_algorithm = 'nearest-exact'
self.extra_args = {} self.extra_args = {}
self.previous_controlnet = None self.previous_controlnet: Union[ControlBase, None] = None
self.extra_conds = [] self.extra_conds = []
self.strength_type = StrengthType.CONSTANT self.strength_type = StrengthType.CONSTANT
self.concat_mask = False self.concat_mask = False
@ -85,6 +98,7 @@ class ControlBase:
self.extra_concat = None self.extra_concat = None
self.extra_hooks: HookGroup = None self.extra_hooks: HookGroup = None
self.preprocess_image = lambda a: a self.preprocess_image = lambda a: a
self.multigpu_clones: dict[torch.device, ControlBase] = {}
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]): def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint self.cond_hint_original = cond_hint
@ -111,17 +125,38 @@ class ControlBase:
def cleanup(self): def cleanup(self):
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
self.previous_controlnet.cleanup() self.previous_controlnet.cleanup()
for device_cnet in self.multigpu_clones.values():
with ControlIsolation(device_cnet):
device_cnet.cleanup()
self.cond_hint = None self.cond_hint = None
self.extra_concat = None self.extra_concat = None
self.timestep_range = None self.timestep_range = None
def get_models(self): def get_models(self):
out = [] out = []
for device_cnet in self.multigpu_clones.values():
out += device_cnet.get_models_only_self()
if self.previous_controlnet is not None: if self.previous_controlnet is not None:
out += self.previous_controlnet.get_models() out += self.previous_controlnet.get_models()
return out return out
def get_models_only_self(self):
'Calls get_models, but temporarily sets previous_controlnet to None.'
with ControlIsolation(self):
return self.get_models()
def get_instance_for_device(self, device):
'Returns instance of this Control object intended for selected device.'
return self.multigpu_clones.get(device, self)
def deepclone_multigpu(self, load_device, autoregister=False):
'''
Create deep clone of Control object where model(s) is set to other devices.
When autoregister is set to True, the deep clone is also added to multigpu_clones dict.
'''
raise NotImplementedError("Classes inheriting from ControlBase should define their own deepclone_multigpu function.")
def get_extra_hooks(self): def get_extra_hooks(self):
out = [] out = []
if self.extra_hooks is not None: if self.extra_hooks is not None:
@ -130,7 +165,7 @@ class ControlBase:
out += self.previous_controlnet.get_extra_hooks() out += self.previous_controlnet.get_extra_hooks()
return out return out
def copy_to(self, c): def copy_to(self, c: ControlBase):
c.cond_hint_original = self.cond_hint_original c.cond_hint_original = self.cond_hint_original
c.strength = self.strength c.strength = self.strength
c.timestep_percent_range = self.timestep_percent_range c.timestep_percent_range = self.timestep_percent_range
@ -284,6 +319,14 @@ class ControlNet(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.control_model = copy.deepcopy(c.control_model)
c.control_model_wrapped = comfy.model_patcher.ModelPatcher(c.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
if autoregister:
self.multigpu_clones[load_device] = c
return c
def get_models(self): def get_models(self):
out = super().get_models() out = super().get_models()
out.append(self.control_model_wrapped) out.append(self.control_model_wrapped)
@ -314,6 +357,10 @@ class QwenFunControlNet(ControlNet):
super().pre_run(model, percent_to_timestep_function) super().pre_run(model, percent_to_timestep_function)
self.set_extra_arg("base_model", model.diffusion_model) self.set_extra_arg("base_model", model.diffusion_model)
def cleanup(self):
self.extra_args.pop("base_model", None)
super().cleanup()
def copy(self): def copy(self):
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model c.control_model = self.control_model
@ -906,6 +953,14 @@ class T2IAdapter(ControlBase):
self.copy_to(c) self.copy_to(c)
return c return c
def deepclone_multigpu(self, load_device, autoregister=False):
c = self.copy()
c.t2i_model = copy.deepcopy(c.t2i_model)
c.device = load_device
if autoregister:
self.multigpu_clones[load_device] = c
return c
def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
compression_ratio = 8 compression_ratio = 8
upscale_algorithm = 'nearest-exact' upscale_algorithm = 'nearest-exact'
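
`ControlIsolation` in the diff above is a context manager that temporarily detaches `previous_controlnet`, so that `cleanup()` and `get_models()` on a per-device clone do not walk the whole chain. The sketch below shows the same pattern on a stand-in class. `Node`, `Isolation`, and `get_names` are hypothetical names used only for illustration and are not part of ComfyUI.

```python
class Node:
    """Stand-in for a chained control object (hypothetical, for illustration)."""

    def __init__(self, name, previous=None):
        self.name = name
        self.previous = previous

    def get_names(self):
        # Walks the chain, like get_models() follows previous_controlnet.
        out = [self.name]
        if self.previous is not None:
            out += self.previous.get_names()
        return out


class Isolation:
    """Temporarily set node.previous to None to stop cascading calls."""

    def __init__(self, node):
        self.node = node
        self.orig_previous = node.previous

    def __enter__(self):
        self.node.previous = None

    def __exit__(self, *args):
        # Runs even if the body returns early or raises.
        self.node.previous = self.orig_previous


if __name__ == "__main__":
    first = Node("first")
    second = Node("second", previous=first)
    with Isolation(second):
        print(second.get_names())  # only the isolated node
    print(second.get_names())      # chain restored
```

Because `__exit__` always restores the original link, a `return` inside the `with` block, as in `get_models_only_self`, still leaves the chain intact.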

View File

@ -1,5 +1,20 @@
import logging
import torch import torch
_CK_STOCHASTIC_ROUNDING_AVAILABLE = False
try:
import comfy_kitchen as ck
_ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
_CK_STOCHASTIC_ROUNDING_AVAILABLE = True
except (AttributeError, ImportError):
logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")
if not _CK_STOCHASTIC_ROUNDING_AVAILABLE:
def _ck_stochastic_rounding_fp8(value, rng, dtype):
raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where( mantissa_scaled = torch.where(
normal_mask, normal_mask,
@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device) generator = torch.Generator(device=value.device)
generator.manual_seed(seed) generator.manual_seed(seed)
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
return _ck_stochastic_rounding_fp8(value, rng, dtype)
output = torch.empty_like(value, dtype=dtype) output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096))) num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices)) slice_size = max(1, round(value.shape[0] / num_slices))
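
The change above guards an optional fast path: it tries to bind `ck.stochastic_rounding_fp8`, records whether that worked in a flag, and installs a stub that raises `NotImplementedError` otherwise, so the caller can branch on the flag and fall back to the pure PyTorch code. A generic standard-library sketch of that pattern follows. The helper `load_optional` is hypothetical and does not reproduce the actual `comfy_kitchen` API.

```python
import importlib
import logging


def load_optional(module_name: str, attr_name: str):
    """Return (callable, available) for an optional accelerated function."""
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr_name), True
    except (AttributeError, ImportError):
        # Covers both a missing package and a package too old to have the attribute.
        logging.warning("%s.%s is unavailable, using the fallback path.", module_name, attr_name)

        def _missing(*args, **kwargs):
            raise NotImplementedError(f"{module_name} does not provide {attr_name}")

        return _missing, False


if __name__ == "__main__":
    fast_sqrt, available = load_optional("math", "sqrt")
    print(available, fast_sqrt(9.0))
```

Catching `AttributeError` alongside `ImportError` matters here because an installed but outdated package imports successfully and only fails on the attribute lookup.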

View File

@ -799,13 +799,15 @@ class ZImagePixelSpace(ChromaRadiance):
""" """
pass pass
class HiDreamO1Pixel(ChromaRadiance): class HiDreamO1Pixel(ChromaRadiance):
"""Pixel-space latent format for HiDream-O1. """Pixel-space latent format for HiDream-O1.
No VAE; the model patches/unpatches raw RGB internally with patch_size=32.
""" """
pass pass
class PixelDiTPixel(ChromaRadiance):
pass
class CogVideoX(LatentFormat): class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -433,11 +433,11 @@ class Attention(nn.Module):
if self.differential: if self.differential:
q, q_diff = q.unbind(dim=1) q, q_diff = q.unbind(dim=1)
k, k_diff = k.unbind(dim=1) k, k_diff = k.unbind(dim=1)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options) out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = out - out_diff out = out - out_diff
else: else:
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = self.to_out(out) out = self.to_out(out)

View File

@ -138,11 +138,11 @@ class Attention(nn.Module):
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype) k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
if self.differential: if self.differential:
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True)) - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
del q, k, v, q_diff, k_diff del q, k, v, q_diff, k_diff
else: else:
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
del q, k, v del q, k, v
return self.to_out(out) return self.to_out(out)

View File

@ -15,15 +15,6 @@ import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit import comfy.ldm.common_dit
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
# ---------------------- Feed Forward Network ----------------------- # ---------------------- Feed Forward Network -----------------------
class GPT2FeedForward(nn.Module): class GPT2FeedForward(nn.Module):
@ -173,8 +164,7 @@ class Attention(nn.Module):
k = self.k_norm(k) k = self.k_norm(k)
v = self.v_norm(v) v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention! if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb) q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
return q, k, v return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)

View File

@ -607,9 +607,13 @@ class HunYuanDiTPlain(nn.Module):
def forward(self, x, t, context, transformer_options = {}, **kwargs): def forward(self, x, t, context, transformer_options = {}, **kwargs):
x = x.movedim(-1, -2) x = x.movedim(-1, -2)
if context.shape[0] >= 2:
uncond_emb, cond_emb = context.chunk(2, dim = 0) swap_cfg_halves = context.shape[0] >= 2
context = torch.cat([cond_emb, uncond_emb], dim = 0)
if swap_cfg_halves:
first_half, second_half = context.chunk(2, dim = 0)
context = torch.cat([second_half, first_half], dim = 0)
main_condition = context main_condition = context
t = 1.0 - t t = 1.0 - t
@ -657,8 +661,8 @@ class HunYuanDiTPlain(nn.Module):
output = self.final_layer(combined) output = self.final_layer(combined)
output = output.movedim(-2, -1) * (-1.0) output = output.movedim(-2, -1) * (-1.0)
if output.shape[0] >= 2: if swap_cfg_halves:
cond_emb, uncond_emb = output.chunk(2, dim = 0) first_half, second_half = output.chunk(2, dim = 0)
return torch.cat([uncond_emb, cond_emb]) output = torch.cat([second_half, first_half], dim = 0)
else:
return output return output
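
The `HunYuanDiTPlain.forward` change computes `swap_cfg_halves` once from the input batch and reuses it for the output, so the two halves of the batch are swapped going in and swapped back coming out under the same condition. The sketch below illustrates that round trip with plain Python lists instead of tensors. `swap_halves` is a hypothetical helper, and the split point is an assumption chosen to match how `chunk(2)` divides a sequence (the first chunk takes the extra element for odd lengths).

```python
def swap_halves(items: list) -> list:
    """Swap the first and second halves of a batch-like list."""
    if len(items) < 2:
        return list(items)
    # First chunk has ceil(n / 2) elements.
    mid = (len(items) + 1) // 2
    return items[mid:] + items[:mid]


if __name__ == "__main__":
    batch = ["uncond_0", "uncond_1", "cond_0", "cond_1"]
    swapped = swap_halves(batch)
    print(swapped)
    print(swap_halves(swapped) == batch)
```

For even batch sizes the swap is its own inverse, which is what lets the same operation restore the original order on the output.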

510
comfy/ldm/lens/model.py Normal file
View File

@ -0,0 +1,510 @@
"""Lens denoising transformer (DiT)"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.flux.layers
import comfy.patcher_extension
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
return comfy.ldm.flux.layers.timestep_embedding(t, dim)


def _lens_position_ids(
    frame: int, height: int, width: int, text_seq_len: int,
    scale_rope: bool = True, device=None,
) -> torch.Tensor:
    """Lens axial (frame, h, w) position ids for joint image + text sequence.

    With ``scale_rope=True`` h/w are centered around 0 (negative + positive
    halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
    caller adds a batch dim for ``EmbedND``.
    """
    if scale_rope:
        h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
                           torch.arange(0, height // 2, device=device)])
        w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
                           torch.arange(0, width // 2, device=device)])
        text_start = max(height // 2, width // 2)
    else:
        h_pos = torch.arange(height, device=device)
        w_pos = torch.arange(width, device=device)
        text_start = max(height, width)
    f_pos = torch.arange(frame, device=device)
    img_ids = torch.zeros(frame, height, width, 3, device=device)
    img_ids[..., 0] = f_pos[:, None, None]
    img_ids[..., 1] = h_pos[None, :, None]
    img_ids[..., 2] = w_pos[None, None, :]
    img_ids = img_ids.reshape(-1, 3)
    # Text positions replicate across all 3 axes (matches original packing).
    txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
    txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
    return torch.cat([img_ids, txt_ids], dim=0)


class _TimestepEmbedder(nn.Module):
    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear_1(x)
        x = F.silu(x)
        return self.linear_2(x)


class LensTimestepProjEmbeddings(nn.Module):
    def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)

    def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        proj = _lens_time_proj(timestep, 256)
        return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))


class GateMLP(nn.Module):
    """SwiGLU MLP."""

    def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
        super().__init__()
        self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
        self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
        self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))


class LensJointAttention(nn.Module):
    """Joint image+text attention with fused QKV per stream."""

    def __init__(
        self,
        query_dim: int,
        added_kv_proj_dim: int,
        dim_head: int = 64,
        heads: int = 8,
        out_dim: Optional[int] = None,
        eps: float = 1e-5,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.heads = self.inner_dim // dim_head
        self.dim_head = dim_head
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
        self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
        # ModuleList([Linear, Identity]) for state-dict key compatibility.
        self.to_out = nn.ModuleList([
            operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
            nn.Identity(),
        ])
        self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        bsz, seq_img, _ = hidden_states.shape
        seq_txt = encoder_hidden_states.shape[1]
        # image stream
        img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
        img_q, img_k, img_v = img_qkv.unbind(dim=2)
        img_q = self.norm_q(img_q)
        img_k = self.norm_k(img_k)
        del img_qkv
        # text stream
        txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
        txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
        txt_q = self.norm_added_q(txt_q)
        txt_k = self.norm_added_k(txt_k)
        # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
        q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
        del img_q, txt_q
        k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
        del img_k, txt_k
        v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
        del img_v, txt_v
        q, k = apply_rope(q, k, freqs_cis)
        if attention_mask is not None:
            expected = (bsz, 1, 1, seq_img + seq_txt)
            if attention_mask.shape != expected:
                raise ValueError(
                    f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
                )
            attention_mask = attention_mask.to(q.dtype)
        out = optimized_attention(
            q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
            transformer_options=transformer_options,
        )
        img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
        txt_out = self.to_add_out(out[:, seq_img:, :])
        return img_out, txt_out


class LensTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        rms_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.attn = LensJointAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            eps=1e-5,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        if rms_norm:
            NormCls = operations.RMSNorm
            norm_kwargs = {}
        else:
            NormCls = operations.LayerNorm
            norm_kwargs = {"elementwise_affine": False}
        mlp_hidden = int(dim / 3 * 8)
        # Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
        self.img_mod = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
        )
        self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
        self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
        self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
        self.txt_mod = nn.Sequential(
            nn.SiLU(),
            operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
        )
        self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
        self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
        self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)

    @staticmethod
    def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        shift, scale, gate = mod_params.chunk(3, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
        txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
        txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
        img_attn, txt_attn = self.attn(
            hidden_states=img_modulated,
            encoder_hidden_states=txt_modulated,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            transformer_options=transformer_options,
        )
        hidden_states = hidden_states + img_gate1 * img_attn
        encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
        hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
        txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
        encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
        return encoder_hidden_states, hidden_states


class _AdaLayerNormContinuousNoAffine(nn.Module):
    """AdaLayerNormContinuous(elementwise_affine=False).

    The reference uses ``scale, shift = chunk(2)`` (scale first) opposite
    to Flux's ``LastLayer``.
    """

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
                 dtype=None, device=None, operations=None) -> None:
        super().__init__()
        self.linear = operations.Linear(
            conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
        )
        self.eps = eps
        self.embedding_dim = embedding_dim

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        emb = self.linear(F.silu(conditioning))
        scale, shift = torch.chunk(emb, 2, dim=-1)
        x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class LensTransformer2DModel(nn.Module):
    """Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 128,
        out_channels: Optional[int] = 32,
        num_layers: int = 48,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        enc_hidden_dim: int = 2880,
        axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
        rms_norm: bool = True,
        multi_layer_encoder_feature: bool = True,
        selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
        image_model=None,  # unused; accepted for detection-side configs.
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        self.multi_layer_encoder_feature = multi_layer_encoder_feature
        self.selected_layer_index = list(selected_layer_index)
        self.dtype = dtype
        self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
        self.time_text_embed = LensTimestepProjEmbeddings(
            embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
        )
        if self.multi_layer_encoder_feature:
            self.txt_norm = nn.ModuleList(
                [operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
                 for _ in self.selected_layer_index]
            )
            self.txt_in = operations.Linear(
                enc_hidden_dim * len(self.selected_layer_index),
                self.inner_dim, bias=True, dtype=dtype, device=device,
            )
        else:
            self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
            self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
        self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
        self.transformer_blocks = nn.ModuleList([
            LensTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                eps=1e-6,
                rms_norm=rms_norm,
                dtype=dtype, device=device, operations=operations,
            )
            for _ in range(num_layers)
        ])
        self.norm_out = _AdaLayerNormContinuousNoAffine(
            self.inner_dim, self.inner_dim, eps=1e-6,
            dtype=dtype, device=device, operations=operations,
        )
        self.proj_out = operations.Linear(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
            dtype=dtype, device=device,
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
        if transformer_options is None:
            transformer_options = {}
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward, self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)

    def _forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
        control: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
        if transformer_options is None:
            transformer_options = {}
        transformer_options = transformer_options.copy()
        patches = transformer_options.get("patches", {})
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})
        B, C, h, w = x.shape
        hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
        if self.multi_layer_encoder_feature:
            L = len(self.selected_layer_index)
            enc_dim = context.shape[-1] // L
            encoder_hidden_states = list(
                context.reshape(B, -1, L, enc_dim).unbind(dim=2)
            )
            text_seq_len = encoder_hidden_states[0].shape[1]
        else:
            encoder_hidden_states = context
            text_seq_len = context.shape[1]
        if attention_mask is None:
            attention_mask = torch.ones(
                (B, text_seq_len), dtype=torch.bool, device=x.device
            )
        img_len = h * w
        joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
        hidden_states = self.img_in(hidden_states)
        timestep = timestep.to(hidden_states.dtype)
        if self.multi_layer_encoder_feature:
            normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
            encoder_hidden_states = torch.cat(normed, dim=-1)
        else:
            encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)
        if "post_input" in patches:
            for p in patches["post_input"]:
                out = p({
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "transformer_options": transformer_options,
                })
                hidden_states = out["img"]
                encoder_hidden_states = out["txt"]
        temb = self.time_text_embed(timestep, hidden_states)
        ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
        freqs_cis = self.pos_embed(ids)
        transformer_options["total_blocks"] = len(self.transformer_blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["txt"], out["img"] = block(
                        hidden_states=args["img"],
                        encoder_hidden_states=args["txt"],
                        temb=args["vec"],
                        freqs_cis=args["pe"],
                        attention_mask=args.get("attn_mask"),
                        transformer_options=args.get("transformer_options"),
                    )
                    return out
                out = blocks_replace[("double_block", i)](
                    {
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "vec": temb,
                        "pe": freqs_cis,
                        "attn_mask": joint_mask,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                encoder_hidden_states = out["txt"]
                hidden_states = out["img"]
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    freqs_cis=freqs_cis,
                    attention_mask=joint_mask,
                    transformer_options=transformer_options,
                )
            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p({
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "x": x,
                        "block_index": i,
                        "transformer_options": transformer_options,
                    })
                    hidden_states = out["img"]
                    encoder_hidden_states = out["txt"]
            if control is not None:
                control_i = control.get("input")
                if control_i is not None and i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        hidden_states[:, :add.shape[1]] += add
        hidden_states = self.norm_out(hidden_states, temb)
        out = self.proj_out(hidden_states)
        return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
        if text_mask.dtype != torch.bool:
            text_mask = text_mask.bool()
        bsz = text_mask.shape[0]
        img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
        joint = torch.cat([img_ones, text_mask], dim=1)
        additive = torch.zeros_like(joint, dtype=torch.float32)
        additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
        return additive[:, None, None, :]
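
Note on the position ids: the layout described in the `_lens_position_ids` docstring can be traced by hand with a plain-Python sketch. This is only an illustration under the assumption that it mirrors the tensor code in this file; the helper names `centered_positions` and `lens_position_ids_py` are hypothetical and it returns tuples instead of a `[seq, 3]` tensor.

```python
def centered_positions(size):
    """Positions centered around 0: a negative half followed by a non-negative half."""
    return list(range(-(size - size // 2), 0)) + list(range(0, size // 2))


def lens_position_ids_py(frame, height, width, text_seq_len, scale_rope=True):
    """Return (frame, h, w) ids for image tokens followed by text tokens."""
    if scale_rope:
        h_pos = centered_positions(height)
        w_pos = centered_positions(width)
        text_start = max(height // 2, width // 2)
    else:
        h_pos = list(range(height))
        w_pos = list(range(width))
        text_start = max(height, width)
    # Row-major over (frame, h, w), matching reshape(-1, 3) on a [F, H, W, 3] grid.
    img_ids = [(f, h, w) for f in range(frame) for h in h_pos for w in w_pos]
    # Text positions repeat the same index on all three axes.
    txt_ids = [(p, p, p) for p in range(text_start, text_start + text_seq_len)]
    return img_ids + txt_ids


ids = lens_position_ids_py(frame=1, height=2, width=3, text_seq_len=2)
```

For a 2x3 grid this gives six image ids with h in {-1, 0} and w in {-2, -1, 0}, followed by text ids starting at `max(1, 1) = 1`, so text positions begin just past the largest positive image coordinate.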


@@ -767,25 +767,25 @@ class LTXAVModel(LTXVModel):
 # Cross-attention timesteps - compress these too
 av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
-    timestep.max().expand_as(a_timestep_flat),
+    a_timestep_flat,
     {"resolution": None, "aspect_ratio": None},
     batch_size=batch_size,
     hidden_dtype=hidden_dtype,
 )
 av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
-    a_timestep.max().expand_as(timestep_flat),
+    timestep_flat,
     {"resolution": None, "aspect_ratio": None},
     batch_size=batch_size,
     hidden_dtype=hidden_dtype,
 )
 av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
-    a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
+    a_timestep_scaled.max().expand_as(timestep_flat) * av_ca_factor,
     {"resolution": None, "aspect_ratio": None},
     batch_size=batch_size,
     hidden_dtype=hidden_dtype,
 )
 av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
-    timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
+    timestep_scaled.max().expand_as(a_timestep_flat) * av_ca_factor,
     {"resolution": None, "aspect_ratio": None},
     batch_size=batch_size,
     hidden_dtype=hidden_dtype,


@@ -1,4 +1,3 @@
-from __future__ import annotations
 import torch
 from torch import nn
 from torch.nn import functional as F


@@ -1,4 +1,3 @@
-from __future__ import annotations
 import threading
 import torch
 from torch import nn


@@ -1,5 +1,4 @@
 # Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
-from __future__ import annotations
 from typing import List, Optional, Tuple


@@ -741,12 +741,12 @@ optimized_attention = attention_basic
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
-elif model_management.xformers_enabled():
-    logging.info("Using xformers attention")
-    optimized_attention = attention_xformers
 elif model_management.flash_attention_enabled():
     logging.info("Using Flash Attention")
     optimized_attention = attention_flash
+elif model_management.xformers_enabled():
+    logging.info("Using xformers attention")
+    optimized_attention = attention_xformers
 elif model_management.pytorch_attention_enabled():
     logging.info("Using pytorch attention")
     optimized_attention = attention_pytorch
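
Note on the reorder: the effect of this hunk is that Flash Attention is now checked before xformers when both are enabled. A minimal sketch of that first-match selection, covering only the four branches visible in the hunk (any later branches are not shown here and are omitted). The names `ATTENTION_PRIORITY` and `pick_attention_backend` are hypothetical, and a plain dict stands in for the `model_management.*_enabled()` calls.

```python
# New priority order as it appears in the hunk: sage, flash, xformers, pytorch.
ATTENTION_PRIORITY = ("sage", "flash", "xformers", "pytorch")


def pick_attention_backend(enabled):
    """Return the first enabled backend in priority order, else the basic fallback."""
    for name in ATTENTION_PRIORITY:
        if enabled.get(name, False):
            return name
    # Mirrors the default assignment `optimized_attention = attention_basic`.
    return "basic"


choice = pick_attention_backend({"xformers": True, "flash": True})
```

With the old order the same input would have selected xformers; with the new order `choice` is `"flash"`.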


@@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
     Embeds scalar timesteps into vector representations.
     """
-    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
         super().__init__()
         if output_size is None:
             output_size = hidden_size
@@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
             operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
         )
         self.frequency_embedding_size = frequency_embedding_size
+        self.max_period = max_period
     def forward(self, t, dtype, **kwargs):
-        t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
+        t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
         t_emb = self.mlp(t_freq)
         return t_emb
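
Note on `max_period`: the new argument sets the lowest frequency of the sinusoidal timestep features. The sketch below uses the common DiT-style formulation (geometric frequencies from 1 down towards `1 / max_period`, cosines followed by sines). That formulation is an assumption here, since the body of `timestep_embedding` is not part of this diff, and the names `sinusoidal_embedding` and `slowest_frequency` are hypothetical.

```python
import math


def sinusoidal_embedding(t, dim, max_period=10000):
    """Embed one scalar timestep as [cos(t * f_i)] + [sin(t * f_i)] for dim // 2 frequencies."""
    half = dim // 2
    # Frequencies decay geometrically from 1.0 towards 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    return [math.cos(a) for a in args] + [math.sin(a) for a in args]


def slowest_frequency(dim, max_period):
    """Smallest frequency used for a given embedding size and max_period."""
    half = dim // 2
    return math.exp(-math.log(max_period) * (half - 1) / half)


default_slowest = slowest_frequency(256, 10000)
short_slowest = slowest_frequency(256, 10)
```

Under this assumption, a smaller `max_period` (such as the value 10 passed by `PixDiT_T2I`) compresses the frequency range, so the slowest feature varies much faster with `t` than it does with the default 10000.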


@@ -1,6 +1,5 @@
 """Pure-torch + scipy geometry helpers for MoGe inference and mesh export."""
-from __future__ import annotations
 from typing import Optional, Tuple


@@ -4,7 +4,6 @@ V1: DINOv2 backbone + multi-output head (points, mask).
 V2: DINOv2 encoder + neck + per-output heads (points, mask, normal, optional metric-scale MLP).
 """
-from __future__ import annotations
 from numbers import Number
 from typing import Any, Dict, List, Optional, Tuple, Union


@@ -1,6 +1,5 @@
 """Building blocks for MoGe: residual conv stack, resamplers, MLP, DINOv2 encoder, v1 head."""
-from __future__ import annotations
 from typing import List, Optional, Sequence, Tuple, Union


@@ -6,7 +6,6 @@ equirect distance map via a multi-scale Poisson + gradient sparse solve.
 Image sampling uses F.grid_sample (GPU); the sparse solve uses lsmr (CPU).
 """
-from __future__ import annotations
 from typing import Callable, List, Optional, Tuple

comfy/ldm/pixeldit/model.py (new file, 239 lines added)

@@ -0,0 +1,239 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.hidream.model import FeedForwardSwiGLU
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder

from .modules import (
    FinalLayer,
    PatchTokenEmbedder,
    PiTBlock,
    PixelTokenEmbedder,
    apply_adaln_,
    precompute_freqs_cis_2d,
)


class MMDiTJointAttention(nn.Module):
    """Joint MMDiT attention with separate Q/K/V/proj for image and text streams.

    RoPE is applied to each stream before concatenation so each stream uses its own
    2D/1D positional encoding. Concat order is [text, image] (text first).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
        self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
        self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
        self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
        self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
        self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)

    def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
        B, Nx, _ = x.shape
        _, Ny, _ = y.shape
        H = self.num_heads
        D = self.head_dim
        qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
        qx, kx, vx = qkv_x.unbind(0)
        qx = self.q_norm_x(qx)
        kx = self.k_norm_x(kx)
        qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
        qy, ky, vy = qkv_y.unbind(0)
        qy = self.q_norm_y(qy)
        ky = self.k_norm_y(ky)
        qx, kx = apply_rope(qx, kx, pos_img[None, None])
        if pos_txt is not None:
            qy, ky = apply_rope(qy, ky, pos_txt[None, None])
        q_joint = torch.cat([qy, qx], dim=2)
        k_joint = torch.cat([ky, kx], dim=2)
        v_joint = torch.cat([vy, vx], dim=2)
        out_joint = optimized_attention(
            q_joint, k_joint, v_joint, H,
            mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
            transformer_options=transformer_options,
        )
        out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
        out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
        return self.proj_x(out_x), self.proj_y(out_y)


class MMDiTBlockT2I(nn.Module):
    def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
        self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
        self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
        self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
        self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))

    def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
        shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
        shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
        x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
        y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
        attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
        x = torch.addcmul(x, gate_msa_x, attn_x)
        y = torch.addcmul(y, gate_msa_y, attn_y)
        x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
        y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
        return x, y


class PixDiT_T2I(nn.Module):
    """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
    (also runs at 512px when fed the appropriate latent size and flow_shift).

    Forward:
        x:         [B, 3, H, W] pixel-space input (no VAE)
        timesteps: [B] in [0, 1000] (ComfyUI flow sampling convention)
        context:   [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
    Returns flow-matching velocity [B, 3, H, W].
    """

    def __init__(
        self,
        in_channels=3,
        num_groups=24,
        hidden_size=1536,
        pixel_hidden_size=16,
        pixel_attn_hidden_size=1152,
        pixel_num_groups=16,
        patch_depth=14,
        pixel_depth=2,
        patch_size=16,
        txt_embed_dim=2304,
        txt_max_length=300,
        use_text_rope=True,
        text_rope_theta=10000.0,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
        pixel_mlp_chunks=2,
    ):
        super().__init__()
        self.dtype = dtype
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.patch_depth = patch_depth
        self.pixel_depth = pixel_depth
        self.patch_size = patch_size
        self.pixel_hidden_size = pixel_hidden_size
        self.pixel_attn_hidden_size = pixel_attn_hidden_size
        self.pixel_num_groups = pixel_num_groups
        self.txt_embed_dim = txt_embed_dim
        self.txt_max_length = txt_max_length
        self.use_text_rope = use_text_rope
        self.text_rope_theta = text_rope_theta
        self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
        self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
        self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
        self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
        self.patch_blocks = nn.ModuleList([
            MMDiTBlockT2I(self.hidden_size, self.num_groups,
                          dtype=dtype, device=device, operations=operations)
            for _ in range(self.patch_depth)
        ])
        self.pixel_blocks = nn.ModuleList([
            PiTBlock(
                self.pixel_hidden_size,
                self.hidden_size,
                patch_size=self.patch_size,
                num_heads=self.num_groups,
                attn_hidden_size=self.pixel_attn_hidden_size,
                attn_num_heads=self.pixel_num_groups,
                dtype=dtype, device=device, operations=operations,
                mlp_chunks=pixel_mlp_chunks,
            )
            for _ in range(self.pixel_depth)
        ])
        self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)

    def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
        return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)

    def _fetch_text_pos(self, length, device, dtype):
        return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)

    def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)

    def _pre_patch_block(self, s, i, **kwargs):
        """Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
        return s

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
        H_orig, W_orig = x.shape[2], x.shape[3]
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        B, _, H, W = x.shape
        Hs = H // self.patch_size
        Ws = W // self.patch_size
        L = Hs * Ws
        pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
        x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
        if context is None or context.dim() != 3:
            raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
        Ltxt = min(context.shape[1], self.txt_max_length)
        y = context[:, :Ltxt, :]
        y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
        y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb)  # y_pos_embedding is a raw nn.Parameter
        condition = F.silu(t_emb)
        pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
        s = self.s_embedder(x_patches)
        for i, blk in enumerate(self.patch_blocks):
            s = self._pre_patch_block(s, i, **kwargs)
            s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
        s = F.silu(t_emb + s)
        s_cond = s.view(B * L, self.hidden_size)
        x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
        for blk in self.pixel_blocks:
            x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
        x_pixels = self.final_layer(x_pixels)
        C_out = self.out_channels
        P2 = self.patch_size * self.patch_size
        x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
        out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        return out[:, :, :H_orig, :W_orig]
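
Note on the pad/crop bookkeeping in `_forward`: the input is padded to a multiple of `patch_size`, split into `Hs * Ws` patch tokens, and the output is cropped back to the original size. A small arithmetic sketch of those quantities, assuming `pad_to_patch_size` pads each spatial dimension up to the next multiple of the patch size (its body is not part of this diff). The helper name `patch_grid` is hypothetical.

```python
def patch_grid(height, width, patch_size=16):
    """Compute padded size, patch grid, and token count for one image."""
    # Amount needed to reach the next multiple of patch_size (0 if already aligned).
    pad_h = (patch_size - height % patch_size) % patch_size
    pad_w = (patch_size - width % patch_size) % patch_size
    padded_h, padded_w = height + pad_h, width + pad_w
    grid_h, grid_w = padded_h // patch_size, padded_w // patch_size
    return {
        "padded": (padded_h, padded_w),
        "grid": (grid_h, grid_w),                  # (Hs, Ws) in _forward
        "tokens": grid_h * grid_w,                 # L in _forward
        "pixels_per_token": patch_size * patch_size,  # P2 in _forward
        "crop": (height, width),                   # final slice back to H_orig, W_orig
    }


info = patch_grid(1000, 1024)
```

For a 1000x1024 input with `patch_size=16`, the height is padded to 1008, the grid is 63x64, and the patch stream sees 4032 tokens, each covering 256 pixels.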


@@ -0,0 +1,187 @@
import torch
import torch.nn as nn

from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
def apply_adaln_(x, shift, scale):
return x.addcmul_(x, scale).add_(shift)


def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
                            ref_grid_h=None, ref_grid_w=None,
                            scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
                            device=None, dtype=torch.float32, **kwargs):
    """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.

    rope_options:
      scale_x / scale_y multiply the position range (RoPE extrapolation).
      shift_x / shift_y offset the position origin (tiled / regional inference).

    With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
    (rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).

    Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
    Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
    """
    dim_axis = dim // 2
    if ref_grid_h is not None and dim_axis > 2:
        h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
        w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
    else:
        h_ntk = w_ntk = 1.0
    x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
    y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
    y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
    x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
    y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
    out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
    return out.to(dtype=dtype)
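To make the NTK-aware scaling in `precompute_freqs_cis_2d` concrete, here is a small sketch of the per-axis theta formula from its docstring, written without torch. `ntk_theta` is a hypothetical name and the grid sizes and `dim_axis=32` below are illustrative assumptions, not values read from a checkpoint:

```python
def ntk_theta(theta, current, ref, dim_axis):
    """Per-axis theta: theta * (current / ref) ** (dim_axis / (dim_axis - 2)).

    Hypothetical helper. Falls back to the plain theta when no reference
    grid is given or when dim_axis <= 2, as the function above does.
    """
    if ref is None or dim_axis <= 2:
        return theta
    return theta * (current / ref) ** (dim_axis / (dim_axis - 2))


# At the reference grid the base is unchanged; at twice the reference grid
# it grows by slightly more than 2x, because the exponent is a bit above 1.
print(ntk_theta(10000.0, 64, 64, 32))
print(ntk_theta(10000.0, 128, 64, 32))
```

Because the exponent `dim_axis / (dim_axis - 2)` approaches 1 for large head dimensions, the base grows roughly in proportion to how far the current grid exceeds the reference grid.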


def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
    """Standard 2D sin/cos absolute positional embedding (ViT-style).

    first half encodes W-coordinates, second half H.
    """
    assert embed_dim % 4 == 0
    grid_h = torch.arange(height, dtype=torch.float32, device=device)
    grid_w = torch.arange(width, dtype=torch.float32, device=device)
    grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
    return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)


class RotaryAttention(nn.Module):
    """Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

    def forward(self, x, pos, mask=None, transformer_options={}):
        B, N, C = x.shape
        H = self.num_heads
        D = self.head_dim
        qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
        x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
        return self.proj(x)


class FinalLayer(nn.Module):
    def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x):
        return self.linear(self.norm(x))


class PatchTokenEmbedder(nn.Module):
    """Linear projection used both for patchified-image tokens and text-feature tokens."""

    def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
        self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()

    def forward(self, x):
        return self.norm(self.proj(x))


class PixelTokenEmbedder(nn.Module):
    """Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""

    def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size_output = hidden_size_output
        self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)

    def forward(self, inputs, patch_size):
        B, _, H, W = inputs.shape
        Hs, Ws = H // patch_size, W // patch_size
        P2 = patch_size * patch_size
        x = inputs.permute(0, 2, 3, 1).contiguous()
        x = self.proj(x)
        pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
        x = x + pos_full.unsqueeze(0)
        x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
        return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)


class PiTBlock(nn.Module):
    """Pixel-level transformer block.

    Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
    runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
    Conditioning is per-pixel adaLN from the patch-level features.
    """

    def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
                 attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
        super().__init__()
        self.pixel_dim = pixel_hidden_size
        self.context_dim = patch_hidden_size
        self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
        self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
        assert self.attn_dim % self.num_heads == 0
        p2 = patch_size * patch_size
        self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
        self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
        self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
        self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
        self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
        self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
        self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
        self._rope_fn = precompute_freqs_cis_2d
        self.mlp_chunks = max(1, int(mlp_chunks))

    def _fetch_pos(self, height, width, device, dtype, **rope_opts):
        return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)

    def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
        BL, P2, _ = x.shape
        Hs, Ws = image_height // patch_size, image_width // patch_size
        L = Hs * Ws
        B = BL // L
        # Attention path uses only msa params; compute, use, free before mlp params allocate.
        msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
        shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
        x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
        x_flat = x_norm.view(BL, P2 * self.pixel_dim)
        x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
        pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
        attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
        attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
        attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
        x = torch.addcmul(x, gate_msa, attn_exp)
        del msa_params, shift_msa, scale_msa, gate_msa
        mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
        shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
        gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
        mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
        del mlp_params, shift_mlp, scale_mlp
        # MLP in chunks since the peak memory usage is huge here
        chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
        for s in range(0, BL, chunk_size):
            e = min(s + chunk_size, BL)
            x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
        return x
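The chunked MLP loop at the end of `PiTBlock.forward` splits the `BL` rows into at most `mlp_chunks` contiguous slices using a ceiling division. A minimal sketch of just the slice boundaries, with a hypothetical helper name and no tensors involved:

```python
def chunk_bounds(total, chunks):
    """Return the (start, end) slices the chunked loop would visit.

    Hypothetical helper mirroring the loop above; assumes total > 0.
    """
    chunks = max(1, int(chunks))
    chunk_size = (total + chunks - 1) // chunks  # ceil(total / chunks)
    return [(s, min(s + chunk_size, total)) for s in range(0, total, chunk_size)]


print(chunk_bounds(10, 3))  # last slice is shorter than the others
print(chunk_bounds(3, 5))   # more chunks requested than rows: one row per slice
```

Only one slice of the MLP's hidden activation is alive at a time, so peak memory in that stage should scale roughly with `chunk_size` rather than `BL`.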

comfy/ldm/pixeldit/pid.py Normal file

@@ -0,0 +1,227 @@
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
body + LQ projection branch injected before each MMDiT patch block.
"""
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

from .model import PixDiT_T2I
from .modules import precompute_freqs_cis_2d


class SigmaAwareGatePerTokenPerDim(nn.Module):
    """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.

    Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
    """

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
        self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))

    def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
        # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
        log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
        sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
        gate = torch.sigmoid(content_logit + sigma_offset)
        return x + (gate * lq).to(x.dtype)
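The gate formula in the docstring above can be checked by hand with scalars. The sketch below assumes a content logit of 2.0 and `exp(log_alpha) = 5.0`; both numbers are illustrative guesses chosen because they reproduce the quoted "~0.88 at sigma=0, ~0.05 at sigma=1" behaviour, and are not read from any checkpoint:

```python
import math


def sigma_gate(content_logit, log_alpha, sigma):
    """Scalar version of gate = sigmoid(content_logit - exp(log_alpha) * sigma)."""
    z = content_logit - math.exp(log_alpha) * sigma
    return 1.0 / (1.0 + math.exp(-z))


# Assumed values (hypothetical): content logit 2.0, alpha = exp(log_alpha) = 5.0.
log_alpha = math.log(5.0)
print(round(sigma_gate(2.0, log_alpha, 0.0), 2))  # gate is mostly open at sigma = 0
print(round(sigma_gate(2.0, log_alpha, 1.0), 2))  # gate is mostly closed at sigma = 1
```

Under these assumptions the LQ feature is mixed in strongly when `sigma` is near zero and is mostly suppressed as `sigma` approaches 1.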


class ResBlock(nn.Module):
    """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""

    def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
        super().__init__()
        self.block = nn.Sequential(
            operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
            operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x)


class LQProjection2D(nn.Module):
    """LQ latent -> per-block patch-aligned features for controlnet-style injection."""

    def __init__(
        self,
        latent_channels: int,
        hidden_dim: int = 512,
        out_dim: int = 1536,
        patch_size: int = 16,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        num_res_blocks: int = 4,
        num_outputs: int = 7,
        interval: int = 2,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()
        self.latent_channels = latent_channels
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.sr_scale = sr_scale
        self.latent_spatial_down_factor = latent_spatial_down_factor
        self.num_outputs = num_outputs
        self.interval = interval
        z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
        self.z_to_patch_ratio = z_to_patch_ratio
        if z_to_patch_ratio >= 1:
            self.latent_fold_factor = 0
            latent_proj_in_ch = latent_channels
        else:
            fold_factor = int(1 / z_to_patch_ratio)
            assert fold_factor * z_to_patch_ratio == 1.0
            self.latent_fold_factor = fold_factor
            latent_proj_in_ch = latent_channels * fold_factor * fold_factor
        layers = [
            operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
        ]
        for _ in range(num_res_blocks):
            layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
        self.latent_proj = nn.Sequential(*layers)
        self.output_heads = nn.ModuleList(
            [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
        )
        self.gate_modules = nn.ModuleList(
            [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
             for _ in range(num_outputs)]
        )

    def is_gate_active(self, block_idx: int) -> bool:
        return block_idx % self.interval == 0

    def output_index(self, block_idx: int) -> int:
        return block_idx // self.interval

    def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
        return self.gate_modules[out_idx](x, lq_feature, sigma)

    def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
        B, z_dim = lq_latent.shape[:2]
        if self.z_to_patch_ratio >= 1:
            if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
                z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
            else:
                z_aligned = lq_latent
        else:
            f = self.latent_fold_factor
            zH_expected, zW_expected = pH * f, pW * f
            if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
                lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
            z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
            z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
        return self.latent_proj(z_aligned)

    def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
        feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
        B, C, H, W = feat.shape
        tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
        return [head(tokens) for head in self.output_heads]
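`is_gate_active` and `output_index` together decide which patch blocks receive an LQ feature and which output head feeds each one. A small sketch of that schedule in plain Python; `lq_schedule` is a hypothetical helper, and the depth of 14 is an assumption taken from the constant used by the detection code and the default `num_outputs=7`, `interval=2`:

```python
def lq_schedule(depth, interval):
    """Map gated block index -> output head index, plus the number of heads needed.

    Hypothetical helper; the head count uses ceil(depth / interval).
    """
    num_outputs = (depth + interval - 1) // interval
    mapping = {i: i // interval for i in range(depth) if i % interval == 0}
    return mapping, num_outputs


mapping, num_outputs = lq_schedule(14, 2)
print(mapping)      # every second block gets its own head
print(num_outputs)
```

With an interval of 2 every even-numbered block is gated, so each of the 7 heads is used exactly once.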


class PidNet(PixDiT_T2I):
    """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""

    def __init__(
        self,
        lq_latent_channels: int = 16,
        lq_hidden_dim: int = 512,
        lq_num_res_blocks: int = 4,
        lq_interval: int = 2,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
        rope_ref_w: int = 1024,
        image_model=None,
        dtype=None, device=None, operations=None,
        **pixdit_kwargs,
    ):
        super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
        self.rope_ref_grid_h = rope_ref_h // self.patch_size
        self.rope_ref_grid_w = rope_ref_w // self.patch_size

        # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
        def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
            return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)

        for blk in self.pixel_blocks:
            blk._rope_fn = _pit_rope_fn
        num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
        self.lq_proj = LQProjection2D(
            latent_channels=lq_latent_channels,
            hidden_dim=lq_hidden_dim,
            out_dim=self.hidden_size,
            patch_size=self.patch_size,
            sr_scale=sr_scale,
            latent_spatial_down_factor=latent_spatial_down_factor,
            num_res_blocks=lq_num_res_blocks,
            num_outputs=num_lq_outputs,
            interval=lq_interval,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
        return precompute_freqs_cis_2d(
            self.hidden_size // self.num_groups,
            height, width,
            ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
            device=device, dtype=dtype, **rope_opts,
        )

    def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
        if not self.lq_proj.is_gate_active(i):
            return s
        out_idx = self.lq_proj.output_index(i)
        if out_idx >= len(pid_lq_features):
            return s
        return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
        if lq_latent is None:
            raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
        expected_c = self.lq_proj.latent_channels
        if lq_latent.shape[1] != expected_c:
            raise ValueError(
                f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
                f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
            )
        B = x.shape[0]
        # Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
        Hs = -(-x.shape[2] // self.patch_size)
        Ws = -(-x.shape[3] // self.patch_size)
        degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
        if degrade_sigma.numel() == 1 and B > 1:
            degrade_sigma = degrade_sigma.expand(B).contiguous()
        lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
        return super()._forward(
            x, timesteps,
            context=context, attention_mask=attention_mask,
            transformer_options=transformer_options,
            pid_lq_features=lq_features,
            pid_degrade_sigma=degrade_sigma,
            **kwargs,
        )


@@ -51,15 +51,6 @@ class FeedForward(nn.Module):
         return hidden_states
-def apply_rotary_emb(x, freqs_cis):
-    if x.shape[1] == 0:
-        return x
-    t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
-    t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
-    return t_out.reshape(*x.shape)
 class QwenTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
         super().__init__()


@@ -16,7 +16,6 @@
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
-from __future__ import annotations
 import comfy.memory_management
 import comfy.utils
 import comfy.model_management
@@ -484,16 +483,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
     return weight
-def prefetch_prepared_value(value, allocate_buffer, stream):
+def prefetch_prepared_value(value, counter, destination, stream, copy):
     if isinstance(value, torch.Tensor):
-        dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
-        comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
+        size = comfy.memory_management.vram_aligned_size(value)
+        offset = counter[0]
+        counter[0] += size
+        if destination is None:
+            return value
+        dest = destination[offset:offset + size]
+        if copy:
+            comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
         return comfy.memory_management.interpret_gathered_like([value], dest)[0]
     elif isinstance(value, weight_adapter.WeightAdapterBase):
-        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
+        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy))
     elif isinstance(value, tuple):
-        return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
+        return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value)
     elif isinstance(value, list):
-        return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
+        return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value]
     return value
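The new `prefetch_prepared_value` signature replaces the `allocate_buffer` callback with a running `counter` and an optional `destination`. One plausible way to use such a signature (an assumption on my part; the call sites are not shown in this hunk) is a sizing pass with `destination=None` followed by a placement pass into a single buffer. A self-contained sketch of that pattern using `bytes` in place of tensors; `plan` and the 16-byte alignment are hypothetical, and the `copy` flag is left out for brevity:

```python
def plan(value, counter, destination, align=16):
    """Walk a nested structure, assigning each buffer an aligned offset.

    Hypothetical stand-in for the tensor version: with destination=None it
    only advances counter[0]; otherwise it copies into destination and
    returns a view onto the copied bytes.
    """
    if isinstance(value, (bytes, bytearray)):
        size = -(-len(value) // align) * align  # round up to the alignment
        offset = counter[0]
        counter[0] += size
        if destination is None:
            return value
        view = memoryview(destination)[offset:offset + size]
        view[:len(value)] = value
        return view[:len(value)]
    if isinstance(value, tuple):
        return tuple(plan(v, counter, destination, align) for v in value)
    if isinstance(value, list):
        return [plan(v, counter, destination, align) for v in value]
    return value  # non-buffer leaves pass through untouched


weights = [b"abc", (b"defgh", b"i"), "not-a-buffer"]
sizing = [0]
plan(weights, sizing, None)          # pass 1: measure the total size
buf = bytearray(sizing[0])
placed = plan(weights, [0], buf)     # pass 2: place into one allocation
print(sizing[0], bytes(placed[1][0]))
```

Both passes walk the structure in the same order, so the offsets computed while sizing match the ones used while placing.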


@@ -1,6 +1,5 @@
 import math
 import ctypes
-import threading
 import dataclasses
 import torch
 from typing import NamedTuple
@@ -10,12 +9,12 @@ from comfy.quant_ops import QuantizedTensor
 class TensorFileSlice(NamedTuple):
     file_ref: object
-    thread_id: int
+    lock: object
     offset: int
     size: int
-def read_tensor_file_slice_into(tensor, destination):
+def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
     if isinstance(tensor, QuantizedTensor):
         if not isinstance(destination, QuantizedTensor):
@@ -23,12 +22,17 @@ def read_tensor_file_slice_into(tensor, destination):
         if tensor._layout_cls != destination._layout_cls:
             return False
-        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
+        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
+                                           destination2=(destination2._qdata if destination2 is not None else None)):
             return False
         dst_orig_dtype = destination._params.orig_dtype
         destination._params.copy_from(tensor._params, non_blocking=False)
         destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
+        if destination2 is not None:
+            dst_orig_dtype = destination2._params.orig_dtype
+            destination2._params.copy_from(destination._params, non_blocking=True)
+            destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
         return True
     info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
@@ -38,7 +42,6 @@ def read_tensor_file_slice_into(tensor, destination):
     file_obj = info.file_ref
     if (destination.device.type != "cpu"
             or file_obj is None
-            or threading.get_ident() != info.thread_id
             or destination.numel() * destination.element_size() < info.size
             or tensor.numel() * tensor.element_size() != info.size
             or tensor.storage_offset() != 0
@@ -48,10 +51,23 @@ def read_tensor_file_slice_into(tensor, destination):
     if info.size == 0:
         return True
+    hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
+    if hostbuf is not None:
+        stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
+        device_ptr = destination2.data_ptr() if destination2 is not None else 0
+        with info.lock:
+            hostbuf.read_file_slice(file_obj, info.offset, info.size,
+                                    offset=destination.data_ptr() - hostbuf.get_raw_address(),
+                                    stream=stream_ptr,
+                                    device_ptr=device_ptr,
+                                    device=None if destination2 is None else destination2.device.index)
+        return True
     buf_type = ctypes.c_ubyte * info.size
     view = memoryview(buf_type.from_address(destination.data_ptr()))
     try:
-        file_obj.seek(info.offset)
-        done = 0
-        while done < info.size:
+        with info.lock:
+            file_obj.seek(info.offset)
+            done = 0
+            while done < info.size:
@@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom):
     extra_ram_release_callback = callback
     RAM_CACHE_HEADROOM = max(0, int(headroom))
-def extra_ram_release(target):
+def extra_ram_release(target, free_active=False):
     if extra_ram_release_callback is None:
         return 0
-    return extra_ram_release_callback(target)
+    return extra_ram_release_callback(target, free_active=free_active)
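The hunks above replace the per-thread ownership check (`thread_id`) with a per-file `lock`, so any thread may read a slice as long as the seek and the reads happen while the lock is held. A minimal stand-alone sketch of that idea with an ordinary file and a `bytearray`; `read_slice_into` is a hypothetical helper, and the loop body is my own assumption since the hunk is cut off at the `while` line:

```python
import os
import tempfile
import threading


def read_slice_into(file_obj, lock, offset, size, destination):
    """Read `size` bytes at `offset` into `destination` while holding `lock`.

    Hypothetical sketch: the lock keeps another thread from moving the shared
    file position between the seek and the reads.
    """
    view = memoryview(destination)[:size]
    with lock:
        file_obj.seek(offset)
        done = 0
        while done < size:
            n = file_obj.readinto(view[done:])
            if not n:  # EOF or destination too small
                return False
            done += n
    return True


with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(bytes(range(256)))
    path = tmp.name

lock = threading.Lock()
dest = bytearray(20)
with open(path, "rb") as f:
    ok = read_slice_into(f, lock, 10, 20, dest)
os.remove(path)
print(ok, list(dest[:3]))
```

Sharing one lock per file object, rather than one global lock, lets reads from different files proceed in parallel.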


@@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
 import comfy.ldm.audio.dit
 import comfy.ldm.audio.embedders
 import comfy.ldm.flux.model
+import comfy.ldm.lens.model
 import comfy.ldm.lightricks.model
 import comfy.ldm.hunyuan_video.model
 import comfy.ldm.cosmos.model
@@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model
 import comfy.ldm.hidream.model
 import comfy.ldm.chroma.model
 import comfy.ldm.chroma_radiance.model
+import comfy.ldm.pixeldit.model
+import comfy.ldm.pixeldit.pid
 import comfy.ldm.ace.model
 import comfy.ldm.omnigen.omnigen2
 import comfy.ldm.qwen_image.model
@@ -1058,6 +1061,27 @@ class Flux2(Flux):
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
         return out
+class Lens(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+        super().__init__(
+            model_config, model_type, device=device,
+            unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
+        )
+    def encode_adm(self, **kwargs):
+        return None # Lens has no pooled/ADM conditioning.
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+        return out
 class GenmoMochi(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@@ -1375,6 +1399,53 @@ class ZImagePixelSpace(Lumina2):
         BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
         self.memory_usage_factor_conds = ("ref_latents",)
+class PixelDiTT2I(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device,
+                         unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
+        return out
+class PiD(PixelDiTT2I):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        BaseModel.__init__(self, model_config, model_type, device=device,
+                           unet_model=comfy.ldm.pixeldit.pid.PidNet)
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        lq_latent = kwargs.get("lq_latent", None)
+        if lq_latent is not None:
+            out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
+        degrade_sigma = kwargs.get("degrade_sigma", None)
+        if degrade_sigma is not None:
+            out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
+        return out
+    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
+        if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+            lq = cond_value.cond
+            dim = window.dim
+            if dim >= lq.ndim:
+                return None
+            lq_proj = self.diffusion_model.lq_proj
+            ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
+            # Map x window indices -> lq indices (deduplicated, sorted, in-bounds).
+            lq_size = lq.size(dim)
+            lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size})
+            if not lq_indices:
+                return None
+            idx = tuple([slice(None)] * dim + [lq_indices])
+            return cond_value._copy_with(lq[idx].to(device))
+        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
 class WAN21(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
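`PiD.resize_cond_for_context_window` maps a window of output-pixel indices onto the much smaller LQ latent by integer division with `ratio = sr_scale * latent_spatial_down_factor`. A small sketch of just that index mapping, using the default ratio of 4 * 8 = 32 and a hypothetical latent of size 4 along the windowed dimension:

```python
def map_window_to_lq(index_list, ratio, lq_size):
    """Deduplicated, sorted, in-bounds latent indices covered by a pixel window.

    Hypothetical helper mirroring the set comprehension in the method above.
    """
    return sorted({i // ratio for i in index_list if 0 <= i // ratio < lq_size})


print(map_window_to_lq(range(32, 96), 32, 4))   # 64 pixels span two latent rows
print(map_window_to_lq(range(96, 160), 32, 4))  # the out-of-range tail is dropped
```

When the result is empty the method returns `None` instead of slicing the latent with an empty index list.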


@@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
         return dit_config
+    # PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
+    _lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
+    if _lq_w_key in state_dict_keys:
+        in_ch = int(state_dict[_lq_w_key].shape[1])
+        _gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
+        num_gates = len({k[len(_gate_prefix):].split('.')[0]
+                         for k in state_dict_keys if k.startswith(_gate_prefix)})
+        dit_config = {"image_model": "pid",
+                      "lq_latent_channels": in_ch,
+                      "latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
+        if num_gates > 0:
+            dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
+        return dit_config
+    if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
+        return {"image_model": "pixeldit_t2i"}
     if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
         dit_config = {}
         dit_config["image_model"] = "lumina2"
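The PiD detection above infers `lq_interval` by counting distinct gate indices in the checkpoint keys and dividing the (hard-coded) patch depth of 14 by that count, rounding up. A self-contained sketch of the same two steps; the helper names, the `model.diffusion_model.` prefix, and the key list are illustrative assumptions rather than keys read from a real checkpoint:

```python
def count_gates(keys, key_prefix):
    """Count distinct gate module indices under '<prefix>lq_proj.gate_modules.'."""
    gate_prefix = "{}lq_proj.gate_modules.".format(key_prefix)
    return len({k[len(gate_prefix):].split(".")[0] for k in keys if k.startswith(gate_prefix)})


def infer_interval(num_gates, patch_depth=14):
    """ceil(patch_depth / num_gates), matching the expression in the hunk above."""
    return (patch_depth + num_gates - 1) // num_gates


# Hypothetical state-dict keys: 7 gates, each with three parameters.
prefix = "model.diffusion_model."
keys = []
for i in range(7):
    keys.append("{}lq_proj.gate_modules.{}.log_alpha".format(prefix, i))
    keys.append("{}lq_proj.gate_modules.{}.content_proj.weight".format(prefix, i))
    keys.append("{}lq_proj.gate_modules.{}.content_proj.bias".format(prefix, i))

num_gates = count_gates(keys, prefix)
print(num_gates, infer_interval(num_gates))
```

Using a set of index strings means several parameters belonging to the same gate are counted once.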
@@ -755,6 +772,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         dit_config["timestep_scale"] = 1000.0
         return dit_config
+    if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
+            and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
+        img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
+        proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
+        multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
+        if multi_layer:
+            enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
+            # Indices are TE-side; the DiT just consumes L layers in order.
+            selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
+        else:
+            enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
+            selected_layer_index = (0,)
+        return {
+            "image_model": "lens",
+            "in_channels": img_in_w.shape[1],
+            "out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
+            "num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
+            "num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
+            "enc_hidden_dim": enc_hidden_dim,
+            "multi_layer_encoder_feature": multi_layer,
+            "selected_layer_index": selected_layer_index,
+        }
     if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
         dit_config = {}
         dit_config["image_model"] = "qwen_image"


@ -15,6 +15,7 @@
You should have received a copy of the GNU General Public License You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>. along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
from __future__ import annotations
import psutil import psutil
import logging import logging
@ -27,12 +28,18 @@ import platform
import weakref import weakref
import gc import gc
import os import os
from contextlib import nullcontext from contextlib import contextmanager, nullcontext
import comfy.memory_management import comfy.memory_management
import comfy.utils import comfy.utils
import comfy.quant_ops import comfy.quant_ops
import comfy_aimdo.host_buffer
import comfy_aimdo.vram_buffer import comfy_aimdo.vram_buffer
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram NO_VRAM = 1 #Very low vram: enable all the options to save vram
@ -203,6 +210,107 @@ def get_torch_device():
else: else:
return torch.device(torch.cuda.current_device()) return torch.device(torch.cuda.current_device())
def get_all_torch_devices(exclude_current=False):
global cpu_state
devices = []
if cpu_state == CPUState.GPU:
# NVIDIA + AMD/ROCm both expose their GPUs through torch.cuda.*;
# without the AMD arm, single-GPU ROCm users get an empty list
# which silently turns unload_all_models() into a no-op.
if is_nvidia() or is_amd():
for i in range(torch.cuda.device_count()):
devices.append(torch.device("cuda", i))
elif is_intel_xpu():
for i in range(torch.xpu.device_count()):
devices.append(torch.device("xpu", i))
elif is_ascend_npu():
for i in range(torch.npu.device_count()):
devices.append(torch.device("npu", i))
elif is_mlu():
for i in range(torch.mlu.device_count()):
devices.append(torch.device("mlu", i))
else:
# Fallback for unhandled GPU backends (e.g. DirectML): at least
# report the current device so callers like unload_all_models()
# do not silently no-op.
devices.append(get_torch_device())
else:
devices.append(get_torch_device())
if exclude_current:
current = get_torch_device()
if current in devices:
devices.remove(current)
return devices
def get_gpu_device_options():
"""Return list of device option strings for node widgets.
Always includes "default" and "cpu". When multiple GPUs are present,
adds "gpu:0", "gpu:1", etc. (vendor-agnostic labels).
"""
options = ["default", "cpu"]
devices = get_all_torch_devices()
if len(devices) > 1:
for i in range(len(devices)):
options.append(f"gpu:{i}")
return options
def get_gpu_device_options_no_cpu():
"""Variant of get_gpu_device_options that omits "cpu".
Intended for components like the VAE selector where running on CPU
is impractical and should not be offered as a choice.
"""
return [o for o in get_gpu_device_options() if o != "cpu"]
def resolve_gpu_device_option(option: str):
"""Resolve a device option string to a torch.device.
Returns None for "default" (let the caller use its normal default).
Returns torch.device("cpu") for "cpu".
For "gpu:N", returns the Nth torch device. Returns None if the
index is out of range, the option string is malformed, or
unrecognized (callers are expected to log their own context-rich
message before falling back to the default device).
"""
if option is None or option == "default":
return None
if option == "cpu":
return torch.device("cpu")
if option.startswith("gpu:"):
try:
idx = int(option[4:])
except ValueError:
return None
devices = get_all_torch_devices()
if 0 <= idx < len(devices):
return devices[idx]
return None
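Review note: the option-string contract described in the docstring above can be exercised without torch. This is a minimal sketch that mirrors the branch structure of `resolve_gpu_device_option`. `resolve_option_sketch` is a hypothetical name, and plain strings stand in for `torch.device` objects.

```python
def resolve_option_sketch(option, devices):
    # devices stands in for the list returned by get_all_torch_devices().
    if option is None or option == "default":
        return None            # caller keeps its normal default
    if option == "cpu":
        return "cpu"
    if option.startswith("gpu:"):
        try:
            idx = int(option[4:])
        except ValueError:
            return None        # malformed index, e.g. "gpu:x"
        if 0 <= idx < len(devices):
            return devices[idx]
    return None                # out of range or unrecognized


DEVICES = ["cuda:0", "cuda:1"]  # invented two-GPU machine
```

`None` is returned both for "default" and for every failure case, so a caller that wants to warn about a bad option has to compare the input string itself.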
@contextmanager
def cuda_device_context(device):
"""Context manager that sets torch.cuda.current_device to match *device*.
Used when running operations on a non-default CUDA device so that custom
CUDA kernels (e.g. comfy_kitchen fp8 quantization) pick up the correct
device index. The previous device is restored on exit.
No-op when *device* is not CUDA, has no explicit index, or already matches
the current device.
"""
prev = None
if device.type == "cuda" and device.index is not None:
prev = torch.cuda.current_device()
if prev != device.index:
torch.cuda.set_device(device)
else:
prev = None
try:
yield
finally:
if prev is not None:
torch.cuda.set_device(prev)
def get_total_memory(dev=None, torch_total_too=False): def get_total_memory(dev=None, torch_total_too=False):
global directml_enabled global directml_enabled
if dev is None: if dev is None:
@ -491,9 +599,21 @@ try:
logging.info("Device: {}".format(get_torch_device_name(get_torch_device()))) logging.info("Device: {}".format(get_torch_device_name(get_torch_device())))
except: except:
logging.warning("Could not pick default device.") logging.warning("Could not pick default device.")
try:
for device in get_all_torch_devices(exclude_current=True):
logging.info("Device: {}".format(get_torch_device_name(device)))
except:
pass
current_loaded_models: list[LoadedModel] = []
current_loaded_models = [] DIRTY_MMAPS = set()
PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024
#Freeing registerables on pressure does imply a GPU sync, so go big on
#the hysteresis so each expensive sync gives us back a good chunk.
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024
def module_size(module): def module_size(module):
module_mem = 0 module_mem = 0
@ -503,30 +623,49 @@ def module_size(module):
module_mem += t.nbytes module_mem += t.nbytes
return module_mem return module_mem
def module_mmap_residency(module, free=False): def mark_mmap_dirty(storage):
mmap_touched_mem = 0 mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None)
module_mem = 0 if mmap_refs is not None:
bounced_mmaps = set() DIRTY_MMAPS.add(mmap_refs[0])
sd = module.state_dict()
for k in sd: def free_pins(size, evict_active=False):
t = sd[k] freed_total = 0
module_mem += t.nbytes for loaded_model in reversed(current_loaded_models):
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() if size <= 0:
if not getattr(storage, "_comfy_tensor_mmap_touched", False): return freed_total
continue model = loaded_model.model
mmap_touched_mem += t.nbytes if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
if not free: freed = model.partially_unload_ram(size)
continue freed_total += freed
storage._comfy_tensor_mmap_touched = False size -= freed
mmap_obj = storage._comfy_tensor_mmap_refs[0] return freed_total
if mmap_obj in bounced_mmaps:
continue def ensure_pin_budget(size, evict_active=False):
mmap_obj.bounce() shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
bounced_mmaps.add(mmap_obj) if shortfall <= 0:
return mmap_touched_mem, module_mem return True
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
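Review note: `ensure_pin_budget` asks for more than it strictly needs (shortfall plus hysteresis) but reports success as soon as the bare shortfall is covered. Here is a minimal sketch of that arithmetic. `ensure_pin_budget_sketch` and `free_fn` are hypothetical, and the 2 GiB headroom is an invented value, since `RAM_CACHE_HEADROOM` is not shown in this diff.

```python
MiB = 1024 * 1024
GiB = 1024 * MiB
PIN_PRESSURE_HYSTERESIS = 256 * MiB  # value taken from the diff


def ensure_pin_budget_sketch(size, headroom, available, free_fn):
    # free_fn(n) is a hypothetical callable that tries to free n bytes
    # and returns how many bytes it actually freed.
    shortfall = size + headroom / 2 - available
    if shortfall <= 0:
        return True
    to_free = shortfall + PIN_PRESSURE_HYSTERESIS
    return free_fn(to_free) >= shortfall


requests = []


def fake_free_600(n):
    requests.append(n)   # record how much was requested
    return 600 * MiB     # pretend 600 MiB could be freed


# 1 GiB wanted, half of a 2 GiB headroom kept back, 1.5 GiB available:
# shortfall = 1 + 1 - 1.5 = 0.5 GiB, request = 0.5 GiB + 256 MiB = 768 MiB.
RESULT_OK = ensure_pin_budget_sketch(GiB, 2 * GiB, 3 * GiB // 2, fake_free_600)
RESULT_SHORT = ensure_pin_budget_sketch(GiB, 2 * GiB, 3 * GiB // 2, lambda n: 100 * MiB)
RESULT_PLENTY = ensure_pin_budget_sketch(GiB, 2 * GiB, 8 * GiB, lambda n: 0)
```

Over-requesting by the hysteresis amount means the next few allocations are less likely to trigger another eviction pass.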
def ensure_pin_registerable(size, evict_active=False):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
return True
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
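Review note: the final `return shortfall <= REGISTERABLE_PIN_HYSTERESIS` above means the call still succeeds when only the original shortfall, and not the extra hysteresis, could be reclaimed. This sketch shows that behaviour. `registerable_sketch` is a hypothetical name, and the `freeable` list (bytes each model can give back, capped at what is asked) is an assumption about how `unregister_inactive_pins` behaves.

```python
GiB = 1024 ** 3
REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024  # value taken from the diff


def registerable_sketch(total_pinned, size, max_pinned, freeable):
    # freeable: assumed per-model amounts that could be unregistered.
    if max_pinned <= 0:
        return False                      # pinning disabled
    shortfall = total_pinned + size - max_pinned
    if shortfall <= 0:
        return True                       # already fits
    shortfall += REGISTERABLE_PIN_HYSTERESIS
    for amount in freeable:
        shortfall -= min(amount, shortfall)
        if shortfall <= 0:
            return True                   # shortfall and hysteresis both met
    # Success if at least the original shortfall was reclaimed.
    return shortfall <= REGISTERABLE_PIN_HYSTERESIS


# 10 GiB pinned out of a 10 GiB limit, 1 GiB more wanted: shortfall is 1 GiB,
# and the target including hysteresis is 3 GiB.
FITS_EXACTLY = registerable_sketch(10 * GiB, GiB, 10 * GiB, [GiB])
FALLS_SHORT = registerable_sketch(10 * GiB, GiB, 10 * GiB, [GiB // 2])
DISABLED = registerable_sketch(0, GiB, -1, [])
```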
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model: ModelPatcher):
self._set_model(model) self._set_model(model)
self.device = model.load_device self.device = model.load_device
self.real_model = None self.real_model = None
@ -534,7 +673,7 @@ class LoadedModel:
self.model_finalizer = None self.model_finalizer = None
self._patcher_finalizer = None self._patcher_finalizer = None
def _set_model(self, model): def _set_model(self, model: ModelPatcher):
self._model = weakref.ref(model) self._model = weakref.ref(model)
if model.parent is not None: if model.parent is not None:
self._parent_model = weakref.ref(model.parent) self._parent_model = weakref.ref(model.parent)
@ -545,6 +684,7 @@ class LoadedModel:
model = self._parent_model() model = self._parent_model()
if model is not None: if model is not None:
self._set_model(model) self._set_model(model)
self.device = model.load_device
@property @property
def model(self): def model(self):
@ -553,9 +693,6 @@ class LoadedModel:
def model_memory(self): def model_memory(self):
return self.model.model_size() return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self): def model_loaded_memory(self):
return self.model.loaded_size() return self.model.loaded_size()
@ -635,15 +772,9 @@ WINDOWS = any(platform.win32_ver())
EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 EXTRA_RESERVED_VRAM = 400 * 1024 * 1024
if WINDOWS: if WINDOWS:
import comfy.windows
EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue
if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards
EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 EXTRA_RESERVED_VRAM += 100 * 1024 * 1024
def get_free_ram():
return comfy.windows.get_free_ram()
else:
def get_free_ram():
return psutil.virtual_memory().available
if args.reserve_vram is not None: if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
@ -657,7 +788,6 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc() cleanup_models_gc()
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = [] unloaded_models = []
@ -673,11 +803,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted: for x in can_unload_sorted:
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
pins_to_free = 1e32 if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device) memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram() if for_dynamic:
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand. #as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size() memory_required -= current_loaded_models[i].model.loaded_size()
@ -685,18 +813,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i) unloaded_model.append(i)
if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded_models.append(current_loaded_models.pop(i))
@ -762,29 +878,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model.detach(unpatch_all=False) model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach() model_to_unload.model_finalizer.detach()
total_memory_required = {} total_memory_required = {}
total_pins_required = {}
total_ram_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
device = loaded_model.device device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
resident_memory, model_memory = loaded_model.model_mmap_residency()
pinned_memory = loaded_model.model.pinned_memory_size()
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, free_memory(total_memory_required[device] * 1.1 + extra_mem,
device, device,
for_dynamic=free_for_dynamic, for_dynamic=free_for_dynamic)
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
@ -1180,6 +1283,7 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {} STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1220,21 +1324,66 @@ def get_aimdo_cast_buffer(offload_stream, device):
if cast_buffer is None: if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer return cast_buffer
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)
if event is not None:
event.synchronize()
delattr(pin_buffer, "_comfy_event")
return pin_buffer
def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers(): def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
if offload_stream is not None: if offload_stream is not None:
offload_stream.synchronize() offload_stream.synchronize()
synchronize() synchronize()
for mmap_obj in DIRTY_MMAPS:
mmap_obj.bounce()
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
STREAM_CAST_BUFFERS.clear() STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache() soft_empty_cache()
def get_offload_stream(device): def get_offload_stream(device):
@ -1280,7 +1429,7 @@ def sync_stream(device, stream):
current_stream(device).wait_stream(stream) current_stream(device).wait_stream(stream)
def cast_to_gathered(tensors, r, non_blocking=False, stream=None): def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
wf_context = nullcontext() wf_context = nullcontext()
if stream is not None: if stream is not None:
wf_context = stream wf_context = stream
@ -1288,17 +1437,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
wf_context = wf_context.as_context(stream) wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context: with wf_context:
for tensor in tensors: for tensor in tensors:
dest_view = dest_views.pop(0) dest_view = dest_views.pop(0)
dest2_view = dest2_views.pop(0) if dest2_views is not None else None
if tensor is None: if tensor is None:
continue continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view):
continue continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"): mark_mmap_dirty(storage)
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking) dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None:
dest2_view.copy_(dest_view, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
@ -1339,14 +1491,18 @@ TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1 MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory: if not args.disable_pinned_memory:
if is_nvidia() or is_amd(): if is_nvidia() or is_amd():
ram = get_total_memory(torch.device("cpu"))
if WINDOWS: if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50% MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50%
else: else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90 MAX_PINNED_MEMORY = ram * 0.90
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def pinned_hostbuf_size(size):
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
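Review note: a few worked values for `pinned_hostbuf_size`. The function body is copied from the diff, with `max_pinned` passed as a parameter instead of read from the `MAX_PINNED_MEMORY` global so the sketch is self-contained. The factor of 2 is taken as-is; its rationale is not stated in the diff.

```python
GiB = 1024 ** 3


def pinned_hostbuf_size_sketch(size, max_pinned):
    # Clamp the request to the pin limit, double it, never go negative.
    return max(0, int(min(size, max_pinned) * 2))


# Request below the limit: doubled request.
BELOW_LIMIT = pinned_hostbuf_size_sketch(8 * GiB, 16 * GiB)
# Request above the limit: doubled limit.
ABOVE_LIMIT = pinned_hostbuf_size_sketch(8 * GiB, 3 * GiB)
# Pinning disabled (limit left at -1): min() yields -1, so the result clamps to 0.
DISABLED = pinned_hostbuf_size_sketch(8 * GiB, -1)
```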
def discard_cuda_async_error(): def discard_cuda_async_error():
try: try:
a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
@ -1378,8 +1534,8 @@ def pin_memory(tensor):
return False return False
size = tensor.nbytes size = tensor.nbytes
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
return False ensure_pin_registerable(size)
ptr = tensor.data_ptr() ptr = tensor.data_ptr()
if ptr == 0: if ptr == 0:
@ -1416,7 +1572,8 @@ def unpin_memory(tensor):
return False return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) size = PINNED_MEMORY.pop(ptr)
TOTAL_PINNED_MEMORY -= size
return True return True
else: else:
logging.warning("Unpin error.") logging.warning("Unpin error.")
@ -1803,7 +1960,34 @@ def soft_empty_cache(force=False):
torch.cuda.ipc_collect() torch.cuda.ipc_collect()
def unload_all_models(): def unload_all_models():
free_memory(1e30, get_torch_device()) for device in get_all_torch_devices():
free_memory(1e30, device)
def unload_model_and_clones(model: ModelPatcher, unload_additional_models=True, all_devices=False):
'Unload only model and its clones - primarily for multigpu cloning purposes.'
initial_keep_loaded: list[LoadedModel] = current_loaded_models.copy()
additional_models = []
if unload_additional_models:
additional_models = model.get_nested_additional_models()
keep_loaded = []
for loaded_model in initial_keep_loaded:
if loaded_model.model is not None:
if model.clone_base_uuid == loaded_model.model.clone_base_uuid:
continue
# check additional models if they are a match
skip = False
for add_model in additional_models:
if add_model.clone_base_uuid == loaded_model.model.clone_base_uuid:
skip = True
break
if skip:
continue
keep_loaded.append(loaded_model)
if not all_devices:
free_memory(1e30, get_torch_device(), keep_loaded)
else:
for device in get_all_torch_devices():
free_memory(1e30, device, keep_loaded)
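Review note: `unload_model_and_clones` works by building the complement, that is, the list of everything that must stay loaded, and passing it to `free_memory` as `keep_loaded`. A minimal sketch of that filter follows. `PatcherSketch` and `split_keep_loaded` are hypothetical, and only the `clone_base_uuid` matching from the diff is modelled.

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class PatcherSketch:
    # Hypothetical stand-in for ModelPatcher: clones share clone_base_uuid.
    name: str
    clone_base_uuid: uuid.UUID = field(default_factory=uuid.uuid4)

    def clone(self, name):
        return PatcherSketch(name, self.clone_base_uuid)


def split_keep_loaded(target, additional, loaded):
    # Anything sharing a base uuid with the target or its additional
    # models is dropped from the keep list, so it becomes unloadable.
    drop_ids = {target.clone_base_uuid} | {m.clone_base_uuid for m in additional}
    return [m for m in loaded if m.clone_base_uuid not in drop_ids]


unet = PatcherSketch("unet")
unet_clone = unet.clone("unet-clone")
controlnet = PatcherSketch("controlnet")   # an additional model of unet
vae = PatcherSketch("vae")                 # unrelated model

KEPT = split_keep_loaded(unet, [controlnet], [unet, unet_clone, controlnet, vae])
KEPT_NAMES = [m.name for m in KEPT]
```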
def debug_memory_summary(): def debug_memory_summary():
if is_amd() or is_nvidia(): if is_amd() or is_nvidia():


@ -35,6 +35,7 @@ import comfy.model_management
import comfy.ops import comfy.ops
import comfy.patcher_extension import comfy.patcher_extension
import comfy.utils import comfy.utils
import comfy_aimdo.host_buffer
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
from comfy.quant_ops import QuantizedTensor from comfy.quant_ops import QuantizedTensor
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
@ -77,12 +78,15 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
def create_model_options_clone(orig_model_options: dict): def create_model_options_clone(orig_model_options: dict):
return comfy.patcher_extension.copy_nested_dicts(orig_model_options) return comfy.patcher_extension.copy_nested_dicts(orig_model_options)
def create_hook_patches_clone(orig_hook_patches): def create_hook_patches_clone(orig_hook_patches, copy_tuples=False):
new_hook_patches = {} new_hook_patches = {}
for hook_ref in orig_hook_patches: for hook_ref in orig_hook_patches:
new_hook_patches[hook_ref] = {} new_hook_patches[hook_ref] = {}
for k in orig_hook_patches[hook_ref]: for k in orig_hook_patches[hook_ref]:
new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:] new_hook_patches[hook_ref][k] = orig_hook_patches[hook_ref][k][:]
if copy_tuples:
for i in range(len(new_hook_patches[hook_ref][k])):
new_hook_patches[hook_ref][k][i] = tuple(new_hook_patches[hook_ref][k][i])
return new_hook_patches return new_hook_patches
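Review note: the `[:]` copy in `create_hook_patches_clone` duplicates each patch list but not the entries inside it, and the new `copy_tuples` flag rebuilds every entry as a tuple. The sketch below reproduces the function body from the diff under a hypothetical name, with invented data whose entries are lists so the difference is visible.

```python
def clone_hook_patches_sketch(orig, copy_tuples=False):
    new = {}
    for hook_ref in orig:
        new[hook_ref] = {}
        for k in orig[hook_ref]:
            # Shallow copy: new list object, same entry objects.
            new[hook_ref][k] = orig[hook_ref][k][:]
            if copy_tuples:
                for i in range(len(new[hook_ref][k])):
                    new[hook_ref][k][i] = tuple(new[hook_ref][k][i])
    return new


ORIG = {"hook": {"layer.weight": [[1.0, "patch-a"]]}}  # invented example data

SHALLOW = clone_hook_patches_sketch(ORIG)
COPIED = clone_hook_patches_sketch(ORIG, copy_tuples=True)

# Appending to the clone's list does not touch the original list.
SHALLOW["hook"]["layer.weight"].append([0.5, "patch-b"])
```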
def wipe_lowvram_weight(m): def wipe_lowvram_weight(m):
@ -117,6 +121,8 @@ def string_to_seed(data):
return comfy.utils.string_to_seed(data) return comfy.utils.string_to_seed(data)
class LowVramPatch: class LowVramPatch:
is_lowvram_patch = True
def __init__(self, key, patches, convert_func=None, set_func=None): def __init__(self, key, patches, convert_func=None, set_func=None):
self.key = key self.key = key
self.patches = patches self.patches = patches
@ -124,11 +130,21 @@ class LowVramPatch:
self.set_func = set_func self.set_func = set_func
self.prepared_patches = None self.prepared_patches = None
def prepare(self, allocate_buffer, stream): def memory_required(self):
self.prepared_patches = [ counter = [0]
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) for patch in self.patches[self.key]:
comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False)
return counter[0]
def prepare(self, destination, stream, copy=True, commit=True):
counter = [0]
prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4])
for patch in self.patches[self.key] for patch in self.patches[self.key]
] ]
if commit:
self.prepared_patches = prepared_patches
return prepared_patches
def clear_prepared(self): def clear_prepared(self):
self.prepared_patches = None self.prepared_patches = None
@ -316,7 +332,10 @@ class ModelPatcher:
self.is_clip = False self.is_clip = False
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
self.cached_patcher_init: tuple[Callable, tuple] | None = None self.cached_patcher_init: tuple[Callable, tuple] | tuple[Callable, tuple, int] | None = None
self.is_multigpu_base_clone = False
self.clone_base_uuid = uuid.uuid4()
if not hasattr(self.model, 'model_loaded_weight_memory'): if not hasattr(self.model, 'model_loaded_weight_memory'):
self.model.model_loaded_weight_memory = 0 self.model.model_loaded_weight_memory = 0
@ -341,9 +360,6 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model) self.size = comfy.model_management.module_size(self.model)
return self.size return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def loaded_size(self): def loaded_size(self):
return self.model.model_loaded_weight_memory return self.model.model_loaded_weight_memory
@ -356,7 +372,8 @@ class ModelPatcher:
#than pays for CFG. So return everything both torch and Aimdo could give us #than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0 aimdo_mem = 0
if comfy.memory_management.aimdo_enabled: if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() aimdo_device = device.index if getattr(device, "type", None) == "cuda" else None
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze(aimdo_device)
return comfy.model_management.get_free_memory(device) + aimdo_mem return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self): def get_clone_model_override(self):
@ -370,6 +387,8 @@ class ModelPatcher:
if self.cached_patcher_init is None: if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
model_override = temp_model_patcher.get_clone_model_override() model_override = temp_model_patcher.get_clone_model_override()
if model_override is None: if model_override is None:
model_override = self.get_clone_model_override() model_override = self.get_clone_model_override()
@ -428,17 +447,111 @@ class ModelPatcher:
n.hook_mode = self.hook_mode n.hook_mode = self.hook_mode
n.cached_patcher_init = self.cached_patcher_init n.cached_patcher_init = self.cached_patcher_init
n.is_multigpu_base_clone = self.is_multigpu_base_clone
n.clone_base_uuid = self.clone_base_uuid
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
callback(self, n) callback(self, n)
return n return n
def deepclone_multigpu(self, new_load_device=None, models_cache: dict[uuid.UUID,ModelPatcher]=None):
logging.info(f"Creating deepclone of {self.model.__class__.__name__} for {new_load_device if new_load_device else self.load_device}.")
if self.cached_patcher_init is None:
raise RuntimeError(
f"Cannot create multigpu deepclone of {self.model.__class__.__name__}: "
"the loader that produced this model does not support multigpu "
"(cached_patcher_init is not initialized). Use a core loader "
"(CheckpointLoaderSimple, UNETLoader, CLIPLoader/DualCLIPLoader, VAELoader), "
"or have the custom loader register a cached_patcher_init factory."
)
comfy.model_management.unload_model_and_clones(self)
# Produce a freshly-loaded patcher from the loader factory so the multigpu
# clone owns its own untainted model weights (rather than relying on
# copy.deepcopy of an already-patched/already-loaded module).
temp_model_patcher: ModelPatcher | list[ModelPatcher] = self.cached_patcher_init[0](*self.cached_patcher_init[1])
if len(self.cached_patcher_init) > 2:
temp_model_patcher = temp_model_patcher[self.cached_patcher_init[2]]
# Override clone()'s normal "share self.model + share backup containers" with
# the pristine model from temp_model_patcher plus empty backup containers --
# the fresh model has no patches applied, so any deepcopy of self's stale
# backup/object_patches_backup/pinned would just propagate dead state that
# no longer corresponds to anything in n.model.
model_override = (temp_model_patcher.model, ({}, {}, {}, set()))
n = self.clone(model_override=model_override)
# clone() copies hook_backup by reference from self; reset since model is pristine.
n.hook_backup = {}
# set load device, if present
if new_load_device is not None:
n.load_device = new_load_device
# Ensure any per-device bookkeeping (e.g. ModelPatcherDynamic.dynamic_pins)
# has an entry for n.load_device on the freshly-loaded n.model. temp_model_patcher's
# __init__ only registered its own (default) load_device.
if hasattr(n, "register_load_device"):
n.register_load_device(n.load_device)
# multigpu clone should not have multigpu additional_models entry
n.remove_additional_models("multigpu")
# multigpu_clone all stored additional_models; make sure circular references are properly handled
if models_cache is None:
models_cache = {}
for key, model_list in n.additional_models.items():
for i in range(len(model_list)):
add_model = n.additional_models[key][i]
if add_model.clone_base_uuid not in models_cache:
models_cache[add_model.clone_base_uuid] = add_model.deepclone_multigpu(new_load_device=new_load_device, models_cache=models_cache)
n.additional_models[key][i] = models_cache[add_model.clone_base_uuid]
for callback in self.get_all_callbacks(CallbacksMP.ON_DEEPCLONE_MULTIGPU):
callback(self, n)
return n
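The models_cache bookkeeping above can be illustrated in isolation. This is a minimal sketch, not the ModelPatcher code: `Node`, `base_uuid` and `deepclone` are hypothetical stand-ins for a patcher, its clone_base_uuid and deepclone_multigpu. One deliberate difference, stated as an assumption of the sketch: the clone is stored in the cache before recursing, so a true cycle also terminates.

```python
from __future__ import annotations
import uuid


class Node:
    """Hypothetical stand-in for a patcher that references other patchers."""

    def __init__(self, name: str):
        self.name = name
        self.base_uuid = uuid.uuid4()
        self.children: list[Node] = []

    def deepclone(self, cache: dict | None = None) -> Node:
        # One cache is shared by the whole recursive walk.
        if cache is None:
            cache = {}
        if self.base_uuid in cache:
            return cache[self.base_uuid]
        clone = Node(self.name)
        clone.base_uuid = self.base_uuid
        # Sketch-specific choice: register before recursing so cycles stop here.
        cache[self.base_uuid] = clone
        clone.children = [child.deepclone(cache) for child in self.children]
        return clone


# a -> [b, shared], b -> [shared, a]: one shared child and one cycle.
a, b, shared = Node("a"), Node("b"), Node("shared")
a.children = [b, shared]
b.children = [shared, a]
a_clone = a.deepclone()
```

With the shared cache, `shared` is cloned once and both parents point at the same clone.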
def match_multigpu_clones(self):
multigpu_models = self.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
new_multigpu_models = []
for mm in multigpu_models:
# clone main model, but bring over relevant props from existing multigpu clone
n = self.clone()
n.load_device = mm.load_device
n.backup = mm.backup
n.object_patches_backup = mm.object_patches_backup
n.hook_backup = mm.hook_backup
n.model = mm.model
n.is_multigpu_base_clone = mm.is_multigpu_base_clone
n.remove_additional_models("multigpu")
orig_additional_models: dict[str, list[ModelPatcher]] = comfy.patcher_extension.copy_nested_dicts(n.additional_models)
n.additional_models = comfy.patcher_extension.copy_nested_dicts(mm.additional_models)
# figure out which additional models are not present in multigpu clone
models_cache = {}
for mm_add_model in mm.get_additional_models():
models_cache[mm_add_model.clone_base_uuid] = mm_add_model
remove_models_uuids = set(list(models_cache.keys()))
for key, model_list in orig_additional_models.items():
for orig_add_model in model_list:
if orig_add_model.clone_base_uuid not in models_cache:
models_cache[orig_add_model.clone_base_uuid] = orig_add_model.deepclone_multigpu(new_load_device=n.load_device, models_cache=models_cache)
existing_list = n.get_additional_models_with_key(key)
existing_list.append(models_cache[orig_add_model.clone_base_uuid])
n.set_additional_models(key, existing_list)
if orig_add_model.clone_base_uuid in remove_models_uuids:
remove_models_uuids.remove(orig_add_model.clone_base_uuid)
# remove duplicate additional models
for key, model_list in n.additional_models.items():
new_model_list = [x for x in model_list if x.clone_base_uuid not in remove_models_uuids]
n.set_additional_models(key, new_model_list)
for callback in self.get_all_callbacks(CallbacksMP.ON_MATCH_MULTIGPU_CLONES):
callback(self, n)
new_multigpu_models.append(n)
self.set_additional_models("multigpu", new_multigpu_models)
def is_clone(self, other): def is_clone(self, other):
if hasattr(other, 'model') and self.model is other.model: if hasattr(other, 'model') and self.model is other.model:
return True return True
return False return False
def clone_has_same_weights(self, clone: 'ModelPatcher'): def clone_has_same_weights(self, clone: ModelPatcher, allow_multigpu=False):
if allow_multigpu:
if self.clone_base_uuid != clone.clone_base_uuid:
return False
else:
if not self.is_clone(clone): if not self.is_clone(clone):
return False return False
@ -1118,8 +1231,12 @@ class ModelPatcher:
# Pinned memory pressure tracking is only implemented for DynamicVram loading # Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0 return 0
def loaded_ram_size(self):
# Loaded RAM pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
pass return 0
def detach(self, unpatch_all=True): def detach(self, unpatch_all=True):
self.eject_model() self.eject_model()
@ -1218,7 +1335,7 @@ class ModelPatcher:
return self.additional_models.get(key, []) return self.additional_models.get(key, [])
def get_additional_models(self): def get_additional_models(self):
all_models = [] all_models: list[ModelPatcher] = []
for models in self.additional_models.values(): for models in self.additional_models.values():
all_models.extend(models) all_models.extend(models)
return all_models return all_models
@ -1272,9 +1389,18 @@ class ModelPatcher:
for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN): for callback in self.get_all_callbacks(CallbacksMP.ON_PRE_RUN):
callback(self) callback(self)
def prepare_state(self, timestep): def prepare_state(self, timestep, model_options):
ignore_multigpu = model_options.get("ignore_multigpu", False)
for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE): for callback in self.get_all_callbacks(CallbacksMP.ON_PREPARE_STATE):
callback(self, timestep) callback(self, timestep, model_options)
if not ignore_multigpu and "multigpu_clones" in model_options:
model_options["ignore_multigpu"] = True
try:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p.prepare_state(timestep, model_options)
finally:
model_options.pop("ignore_multigpu", None)
def restore_hook_patches(self): def restore_hook_patches(self):
if self.hook_patches_backup is not None: if self.hook_patches_backup is not None:
@ -1287,12 +1413,18 @@ class ModelPatcher:
def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]): def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: comfy.hooks.HookGroup, model_options: dict[str]):
curr_t = t[0] curr_t = t[0]
reset_current_hooks = False reset_current_hooks = False
multigpu_kf_changed_cache = None
transformer_options = model_options.get("transformer_options", {}) transformer_options = model_options.get("transformer_options", {})
for hook in hook_group.hooks: for hook in hook_group.hooks:
changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options) changed = hook.hook_keyframe.prepare_current_keyframe(curr_t=curr_t, transformer_options=transformer_options)
# if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref; # if keyframe changed, remove any cached HookGroups that contain hook with the same hook_ref;
# this will cause the weights to be recalculated when sampling # this will cause the weights to be recalculated when sampling
if changed: if changed:
# cache changed for multigpu usage
if "multigpu_clones" in model_options:
if multigpu_kf_changed_cache is None:
multigpu_kf_changed_cache = []
multigpu_kf_changed_cache.append(hook)
# reset current_hooks if contains hook that changed # reset current_hooks if contains hook that changed
if self.current_hooks is not None: if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks: for current_hook in self.current_hooks.hooks:
@ -1304,6 +1436,28 @@ class ModelPatcher:
self.cached_hook_patches.pop(cached_group) self.cached_hook_patches.pop(cached_group)
if reset_current_hooks: if reset_current_hooks:
self.patch_hooks(None) self.patch_hooks(None)
if "multigpu_clones" in model_options:
for p in model_options["multigpu_clones"].values():
p: ModelPatcher
p._handle_changed_hook_keyframes(multigpu_kf_changed_cache)
def _handle_changed_hook_keyframes(self, kf_changed_cache: list[comfy.hooks.Hook]):
'Used to handle multigpu behavior inside prepare_hook_patches_current_keyframe.'
if kf_changed_cache is None:
return
reset_current_hooks = False
# reset current_hooks if contains hook that changed
for hook in kf_changed_cache:
if self.current_hooks is not None:
for current_hook in self.current_hooks.hooks:
if current_hook == hook:
reset_current_hooks = True
break
for cached_group in list(self.cached_hook_patches.keys()):
if cached_group.contains(hook):
self.cached_hook_patches.pop(cached_group)
if reset_current_hooks:
self.patch_hooks(None)
def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None, def register_all_hook_patches(self, hooks: comfy.hooks.HookGroup, target_dict: dict[str], model_options: dict=None,
registered: comfy.hooks.HookGroup = None): registered: comfy.hooks.HookGroup = None):
@ -1550,9 +1704,30 @@ class ModelPatcherDynamic(ModelPatcher):
super().__init__(model, load_device, offload_device, size, weight_inplace_update) super().__init__(model, load_device, offload_device, size, weight_inplace_update)
if not hasattr(self.model, "dynamic_vbars"): if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {} self.model.dynamic_vbars = {}
if not hasattr(self.model, "dynamic_pins"):
self.model.dynamic_pins = {}
self.register_load_device(self.load_device)
self.non_dynamic_delegate_model = None self.non_dynamic_delegate_model = None
assert load_device is not None assert load_device is not None
def register_load_device(self, device):
"""Ensure dynamic_pins has an entry for *device*.
Called from __init__ and also from any code that retargets an
already-constructed patcher to a new load_device (e.g. the
Select{Model,CLIP,VAE}Device selector nodes); without this entry
partially_unload_ram() raises KeyError when it tries to read the
per-device pin state.
"""
if device not in self.model.dynamic_pins:
self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"hostbufs_initialized": False,
"failed": False,
"active": False,
}
def is_dynamic(self): def is_dynamic(self):
return True return True
@ -1589,6 +1764,16 @@ class ModelPatcherDynamic(ModelPatcher):
#use all ModelPatcherDynamic this is ignored and its all done dynamically. #use all ModelPatcherDynamic this is ignored and its all done dynamically.
return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
def restore_loaded_backups(self):
restored = self.model.model_loaded_weight_memory
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
self.model.model_loaded_weight_memory = 0
return restored
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
@ -1605,12 +1790,20 @@ class ModelPatcherDynamic(ModelPatcher):
num_patches = 0 num_patches = 0
allocated_size = 0 allocated_size = 0
self.model.model_loaded_weight_memory = 0 self.restore_loaded_backups()
with self.use_ejected(): with self.use_ejected():
self.unpatch_hooks() self.unpatch_hooks()
vbar = self._vbar_get(create=True) vbar = self._vbar_get(create=True)
pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False
pin_state["active"] = True
if vbar is not None: if vbar is not None:
vbar.prioritize() vbar.prioritize()
@ -1636,7 +1829,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.patches: if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0) return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) lowvram_patch = LowVramPatch(key, self.patches)
lowvram_patch._pin_state = pin_state
setattr(m, param_key + "_lowvram_function", lowvram_patch)
num_patches += 1 num_patches += 1
else: else:
setattr(m, param_key + "_lowvram_function", None) setattr(m, param_key + "_lowvram_function", None)
@ -1653,6 +1848,9 @@ class ModelPatcherDynamic(ModelPatcher):
def force_load_param(self, param_key, device_to): def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key) key = key_param_name_to_key(n, param_key)
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
return
if key in self.backup: if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to, force_cast=True) self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
@ -1662,17 +1860,26 @@ class ModelPatcherDynamic(ModelPatcher):
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True m.comfy_cast_weights = True
m.pin_failed = False
m.seed_key = n m.seed_key = n
m._pin_state = pin_state
set_dirty(m, dirty) set_dirty(m, dirty)
# Models that mix tiny and giant weights can cause lopsided stream buffer
# rotations and stalls. Force the tiny ones over.
if module_mem > 16 * 1024:
force_load, v_weight_size = setup_param(self, m, n, "weight") force_load, v_weight_size = setup_param(self, m, n, "weight")
force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
force_load = force_load or force_load_bias force_load = force_load or force_load_bias
v_weight_size += v_weight_bias v_weight_size += v_weight_bias
if force_load: if force_load:
logging.info(f"Module {n} has resizing Lora - force loading") logging.info(f"Module {n} has resizing Lora - force loading")
else:
force_load=True
if force_load:
if hasattr(m, "_v"):
comfy_aimdo.model_vbar.vbar_unpin(m._v)
delattr(m, "_v")
force_load_param(self, "weight", device_to) force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to) force_load_param(self, "bias", device_to)
else: else:
@ -1730,33 +1937,62 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0 if vbar is None else vbar.free_memory(memory_to_free) freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
if freed < memory_to_free: if freed < memory_to_free:
for key in list(self.backup.keys()): freed += self.restore_loaded_backups()
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
return freed return freed
def pinned_memory_size(self): def loaded_ram_size(self):
total = 0 return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
loading = self._load_list(for_dynamic=True) self.model.dynamic_pins[self.load_device]["patches"][0].size)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def partially_unload_ram(self, ram_to_unload): def pinned_memory_size(self):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device) return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
for x in loading: self.model.dynamic_pins[self.load_device]["patches"][3][0])
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m) def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
split -= 1
stack_split[0] = split
if not module._pin_registered:
continue
size = module._pin.numel() * module._pin.element_size()
if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0:
comfy.model_management.discard_cuda_async_error()
continue
module._pin_registered = False
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0: if ram_to_unload <= 0:
return return freed
return freed
def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)
if module._pin_registered:
comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size)
pinned_size[0] = max(0, pinned_size[0] - size)
freed += size
ram_to_unload -= size
if ram_to_unload <= 0:
return freed
return freed
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
#This isn't used by the core at all and can only be to load a model out of #This isn't used by the core at all and can only be to load a model out of

248
comfy/multigpu.py Normal file

@ -0,0 +1,248 @@
from __future__ import annotations
import queue
import threading
import torch
import logging
from collections import namedtuple
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
import comfy.utils
import comfy.patcher_extension
import comfy.model_management
class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
def __init__(self, devices: list[torch.device]):
self._workers: list[threading.Thread] = []
self._work_queues: dict[torch.device, queue.Queue] = {}
self._result_queues: dict[torch.device, queue.Queue] = {}
for device in devices:
wq = queue.Queue()
rq = queue.Queue()
self._work_queues[device] = wq
self._result_queues[device] = rq
t = threading.Thread(target=self._worker_loop, args=(device, wq, rq), daemon=True)
t.start()
self._workers.append(t)
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:
item = work_q.get()
if item is None:
return
result_q.put((None, e))
return
while True:
item = work_q.get()
if item is None:
break
fn, args, kwargs = item
try:
result = fn(*args, **kwargs)
result_q.put((result, None))
except Exception as e:
result_q.put((None, e))
def submit(self, device: torch.device, fn, *args, **kwargs):
self._work_queues[device].put((fn, args, kwargs))
def get_result(self, device: torch.device):
return self._result_queues[device].get()
@property
def devices(self) -> list[torch.device]:
return list(self._work_queues.keys())
def shutdown(self):
for wq in self._work_queues.values():
wq.put(None) # sentinel
for t in self._workers:
t.join(timeout=5.0)
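The submit/get_result pattern of MultiGPUThreadPool can be tried without any GPU. The following is a minimal sketch under stated assumptions: `PerKeyWorkerPool` is a hypothetical name, plain strings stand in for torch.device keys, and the torch.cuda.set_device() call at thread start is omitted.

```python
import queue
import threading


class PerKeyWorkerPool:
    """One persistent worker thread per key, each with its own work/result queue."""

    def __init__(self, keys):
        self._work = {}
        self._results = {}
        self._threads = []
        for key in keys:
            wq, rq = queue.Queue(), queue.Queue()
            self._work[key] = wq
            self._results[key] = rq
            t = threading.Thread(target=self._loop, args=(wq, rq), daemon=True)
            t.start()
            self._threads.append(t)

    @staticmethod
    def _loop(wq, rq):
        while True:
            item = wq.get()
            if item is None:  # sentinel: stop this worker
                break
            fn, args, kwargs = item
            try:
                rq.put((fn(*args, **kwargs), None))
            except Exception as e:
                # Errors travel back to the caller as the second tuple element.
                rq.put((None, e))

    def submit(self, key, fn, *args, **kwargs):
        self._work[key].put((fn, args, kwargs))

    def get_result(self, key):
        return self._results[key].get()  # blocks until the worker answers

    def shutdown(self):
        for wq in self._work.values():
            wq.put(None)
        for t in self._threads:
            t.join(timeout=5.0)


pool = PerKeyWorkerPool(["gpu:1", "gpu:2"])
pool.submit("gpu:1", pow, 2, 10)
pool.submit("gpu:2", int, "not a number")
ok_result, ok_error = pool.get_result("gpu:1")
bad_result, bad_error = pool.get_result("gpu:2")
pool.shutdown()
```

Each result is a `(value, exception)` pair, so the caller decides whether to re-raise. Results for one key come back in submission order because each key has exactly one worker.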
class GPUOptions:
def __init__(self, device_index: int, relative_speed: float):
self.device_index = device_index
self.relative_speed = relative_speed
def clone(self):
return GPUOptions(self.device_index, self.relative_speed)
def create_dict(self):
return {
"relative_speed": self.relative_speed
}
class GPUOptionsGroup:
def __init__(self):
self.options: dict[int, GPUOptions] = {}
def add(self, info: GPUOptions):
self.options[info.device_index] = info
def clone(self):
c = GPUOptionsGroup()
for opt in self.options.values():
c.add(opt)
return c
def register(self, model: ModelPatcher):
opts_dict = {}
# get devices that are valid for this model
devices: list[torch.device] = [model.load_device]
for extra_model in model.get_additional_models_with_key("multigpu"):
extra_model: ModelPatcher
devices.append(extra_model.load_device)
# create dictionary with actual device mapped to its GPUOptions
device_opts_list: list[GPUOptions] = []
for device in devices:
device_opts = self.options.get(device.index, GPUOptions(device_index=device.index, relative_speed=1.0))
opts_dict[device] = device_opts.create_dict()
device_opts_list.append(device_opts)
# make relative_speed relative to 1.0
min_speed = min([x.relative_speed for x in device_opts_list])
for value in opts_dict.values():
value['relative_speed'] /= min_speed
model.model_options['multigpu_options'] = opts_dict
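The "make relative_speed relative to 1.0" step in register() divides every speed by the slowest one. A minimal sketch, assuming string device names in place of torch.device keys; `normalize_speeds` is a hypothetical helper, not part of the diff:

```python
def normalize_speeds(speeds: dict) -> dict:
    """Rescale speeds so the slowest device has relative_speed 1.0."""
    slowest = min(speeds.values())
    return {device: speed / slowest for device, speed in speeds.items()}


normalized = normalize_speeds({"cuda:0": 2.0, "cuda:1": 3.0})
```

Here `normalized` maps "cuda:0" to 1.0 and "cuda:1" to 1.5. Only the ratios matter to the later work split, so the rescaling changes readability, not the result.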
def create_multigpu_deepclones(model: ModelPatcher, max_gpus: int, gpu_options: GPUOptionsGroup=None, reuse_loaded=False):
'Prepare ModelPatcher to contain deepclones of its BaseModel and related properties.'
model = model.clone()
# check if multigpu is already prepared - get the load devices from them if possible to exclude
skip_devices = set()
multigpu_models = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) > 0:
for mm in multigpu_models:
skip_devices.add(mm.load_device)
skip_devices = list(skip_devices)
# Exclude the primary model's actual device, not the global current device:
# after SelectModelDevice(gpu:N) the primary may not live on the process's
# current CUDA device, and excluding the wrong device picks bad extras.
all_devices = comfy.model_management.get_all_torch_devices(exclude_current=False)
full_extra_devices = [d for d in all_devices if d != model.load_device]
limit_extra_devices = full_extra_devices[:max_gpus-1]
extra_devices = limit_extra_devices.copy()
# exclude skipped devices
for skip in skip_devices:
if skip in extra_devices:
extra_devices.remove(skip)
# create new deepclones
if len(extra_devices) > 0:
for device in extra_devices:
device_patcher = None
if reuse_loaded:
# Only reuse a previously-loaded MultiGPU clone. A SelectModelDevice
# patcher on the same device shares clone_base_uuid but has
# is_multigpu_base_clone=False, which would later be filtered out by
# prepare_model_patcher_multigpu_clones() and silently shrink the
# work split back to one GPU.
loaded_models: list[ModelPatcher] = comfy.model_management.loaded_models()
for lm in loaded_models:
if lm.model is None:
continue
if lm.load_device != device:
continue
if lm.clone_base_uuid != model.clone_base_uuid:
continue
if not getattr(lm, "is_multigpu_base_clone", False):
continue
device_patcher = lm.clone()
logging.info(f"Reusing loaded multigpu deepclone of {device_patcher.model.__class__.__name__} for {device}")
break
if device_patcher is None:
device_patcher = model.deepclone_multigpu(new_load_device=device)
# Always flag the clone; whether reused or freshly deepcloned, it must
# advertise itself as a MultiGPU base clone so the cond scheduler picks
# it up in prepare_model_patcher_multigpu_clones().
device_patcher.is_multigpu_base_clone = True
multigpu_models = model.get_additional_models_with_key("multigpu")
multigpu_models.append(device_patcher)
model.set_additional_models("multigpu", multigpu_models)
model.match_multigpu_clones()
if gpu_options is None:
gpu_options = GPUOptionsGroup()
gpu_options.register(model)
else:
logging.info("No extra torch devices need initialization, skipping initializing MultiGPU Work Units.")
# only keep model clones that don't go 'past' the intended max_gpu count;
# this prunes any inherited multigpu clones whose load_device is no longer allowed
# when max_gpus is lowered between runs.
allowed_devices = set(limit_extra_devices)
allowed_devices.add(model.load_device)
multigpu_models = model.get_additional_models_with_key("multigpu")
new_multigpu_models = [m for m in multigpu_models if m.load_device in allowed_devices]
if len(new_multigpu_models) != len(multigpu_models):
model.set_additional_models("multigpu", new_multigpu_models)
model.match_multigpu_clones()
return model
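The device selection at the top of create_multigpu_deepclones() can be sketched as a pure function. This assumes string device names and a hypothetical `pick_extra_devices` helper; it mirrors the order of operations shown above: drop the primary's own device, cap at max_gpus - 1, then skip devices that already have a clone.

```python
def pick_extra_devices(all_devices, primary, max_gpus, already_cloned):
    """Return the devices that still need a fresh deepclone."""
    # Exclude the primary model's own device, not a global "current" device.
    extras = [d for d in all_devices if d != primary]
    # The primary counts toward max_gpus, so at most max_gpus - 1 extras.
    limited = extras[:max_gpus - 1]
    # Devices with an existing multigpu clone need no new one.
    return [d for d in limited if d not in already_cloned]


needed = pick_extra_devices(
    all_devices=["cuda:0", "cuda:1", "cuda:2", "cuda:3"],
    primary="cuda:1",
    max_gpus=3,
    already_cloned={"cuda:0"},
)
```

In this example `needed` is `["cuda:2"]`: "cuda:0" is within the cap but already has a clone, and "cuda:3" falls outside the cap.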
LoadBalance = namedtuple('LoadBalance', ['work_per_device', 'idle_time'])
def load_balance_devices(model_options: dict[str], total_work: int, return_idle_time=False, work_normalized: int=None):
'Optimize work assigned to different devices, accounting for their relative speeds and splittable work.'
opts_dict = model_options['multigpu_options']
devices = list(model_options['multigpu_clones'].keys())
speed_per_device = []
work_per_device = []
# get sum of each device's relative_speed
total_speed = 0.0
for opts in opts_dict.values():
total_speed += opts['relative_speed']
# get relative work for each device;
# obtained by w = (W*r)/R
for device in devices:
relative_speed = opts_dict[device]['relative_speed']
relative_work = (total_work*relative_speed) / total_speed
speed_per_device.append(relative_speed)
work_per_device.append(relative_work)
# relative work must be expressed in whole numbers, but likely is a decimal;
# perform rounding while maintaining total sum equal to total work (sum of relative works)
work_per_device = round_preserved(work_per_device)
dict_work_per_device = {}
for device, relative_work in zip(devices, work_per_device):
dict_work_per_device[device] = relative_work
if not return_idle_time:
return LoadBalance(dict_work_per_device, None)
# divide relative work by relative speed to get estimated completion time of said work by each device;
# time here is relative and does not correspond to real-world units
completion_time = [w/r for w,r in zip(work_per_device, speed_per_device)]
# calculate relative time spent by the devices waiting on each other after their work is completed
idle_time = abs(min(completion_time) - max(completion_time))
# if need to compare work idle time, need to normalize to a common total work
if work_normalized:
idle_time *= (work_normalized/total_work)
return LoadBalance(dict_work_per_device, idle_time)
def round_preserved(values: list[float]):
'Round all values in a list, preserving the combined sum of values.'
# get floor of values; casting to int does it too
floored = [int(x) for x in values]
total_floored = sum(floored)
# get remainder to distribute
remainder = round(sum(values)) - total_floored
# pair values with fractional portions
fractional = [(i, x-floored[i]) for i, x in enumerate(values)]
# sort by fractional part in descending order
fractional.sort(key=lambda x: x[1], reverse=True)
# distribute the remainder
for i in range(remainder):
index = fractional[i][0]
floored[index] += 1
return floored
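A worked example of the split w = (W*r)/R followed by sum-preserving rounding. This is a self-contained sketch: `round_preserved` restates the function above in compact form, while `split_work` and the string device names are hypothetical simplifications of load_balance_devices().

```python
def round_preserved(values):
    """Round to integers while keeping the total equal to round(sum(values))."""
    floored = [int(x) for x in values]  # floor for non-negative inputs
    remainder = round(sum(values)) - sum(floored)
    # Hand the leftover units to the entries with the largest fractional parts.
    order = sorted(range(len(values)), key=lambda i: values[i] - floored[i], reverse=True)
    for i in order[:remainder]:
        floored[i] += 1
    return floored


def split_work(total_work, speeds):
    """Assign whole work units to devices in proportion to relative speed."""
    total_speed = sum(speeds.values())
    shares = [total_work * s / total_speed for s in speeds.values()]
    return dict(zip(speeds, round_preserved(shares)))


speeds = {"cuda:0": 1.0, "cuda:1": 2.0}
work = split_work(10, speeds)
# Relative completion time per device, then the gap between fastest and slowest.
completion = [work[d] / speeds[d] for d in speeds]
idle_time = max(completion) - min(completion)
```

With 10 units and speeds 1.0 and 2.0, the raw shares are about 3.33 and 6.67. Flooring gives 3 and 6, and the one leftover unit goes to the larger fraction, so `work` is `{"cuda:0": 3, "cuda:1": 7}` and `idle_time` is 0.5 in relative units.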


@ -18,6 +18,7 @@
import torch import torch
import logging import logging
import contextlib
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature from comfy.cli_args import args, PerformanceFeature
import comfy.float import comfy.float
@ -75,6 +76,8 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True): def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -91,6 +94,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None offload_stream = None
cast_buffer = None cast_buffer = None
cast_buffer_offset = 0 cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest): def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream nonlocal offload_stream
@ -124,6 +130,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size cast_buffer_offset += buffer_size
return buffer return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
if stream_pin_hostbuf is None:
return None
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules: for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v) signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -162,23 +184,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None: if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size) xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None: def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
comfy.pinned_memory.pin_memory(s) if xfer_source is not None:
pin = comfy.pinned_memory.get_pin(s) if getattr(xfer_source, "is_lowvram_patch", False):
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
else: else:
pin = None comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
def handle_pin(m, pin, source, dest, subset="weights", size=None):
if pin is not None: if pin is not None:
comfy.model_management.cast_to_gathered(xfer_source, pin) cast_maybe_lowvram_patch([pin], dest, offload_stream)
xfer_source = [ pin ] return
#send it over if signature is None:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
if isinstance(source, list):
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
else:
cast_maybe_lowvram_patch(source, pin, None)
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
for param_key in ("weight", "bias"): for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_source = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None: if lowvram_source is not None:
ensure_offload_stream(s, cast_buffer_offset, False) ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) lowvram_size = lowvram_source.memory_required()
lowvram_dest = get_cast_buffer(lowvram_size)
lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True)
pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches")
handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size)
prefetch["xfer_dest"] = xfer_dest prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest prefetch["cast_dest"] = cast_dest
@ -186,6 +232,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch s._prefetch = prefetch
if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
if isinstance(xfer_source, list):
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
else:
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
return offload_stream return offload_stream
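
Note on the hunk above: the new code reserves one slice of a shared pinned staging buffer per queued transfer, grows the buffer (plus headroom) when the reserved total exceeds its size, and falls back to direct copies if the resize fails. A minimal pure-Python sketch of that flow follows. It assumes `bytearray` objects in place of the pinned host buffer and device tensors; `flush_queue`, `HEADROOM` and `can_grow` are hypothetical names for illustration, not ComfyUI APIs.

```python
HEADROOM = 16  # hypothetical value; stands in for STREAM_PIN_BUFFER_HEADROOM


def flush_queue(queue, staging, total, can_grow=True):
    """Copy each (source, offset, size, dest) entry through one staging buffer.

    queue   : list of (bytes, int, int, bytearray) tuples
    staging : bytearray reused across calls
    total   : number of bytes the queue has reserved in the staging buffer
    Returns "staged" or "direct" to show which path ran.
    """
    if len(staging) < total:
        if not can_grow:
            # Fallback path: skip the staging buffer and copy directly.
            for source, _, size, dest in queue:
                dest[:size] = source
            return "direct"
        # Grow to the reserved total plus some headroom for later calls.
        staging.extend(bytes(total + HEADROOM - len(staging)))
    for source, offset, size, dest in queue:
        staging[offset:offset + size] = source       # source -> staging slice
        dest[:size] = staging[offset:offset + size]  # staging slice -> destination
    return "staged"
```

The real code issues the second copy on the offload stream and records an event on the host buffer afterwards; the sketch only shows the offset bookkeeping.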
@ -985,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None return grad_input, grad_weight, grad_bias, None, None, None
# Quantized-weight module helpers
def _quantized_apply(module, fn, recurse=True):
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
if recurse:
for child in module.children():
child._apply(fn)
for key, param in module._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in module._buffers.items():
if buf is not None:
module._buffers[key] = fn(buf)
return module
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
"""Shared _load_from_state_dict body for quantized-weight modules.
Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
or Parameter-wrapped QuantizedTensor, then calls super_load and strips
consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
and disabled formats from module._disabled_formats.
"""
device = module.factory_kwargs["device"]
compute_dtype = module.factory_kwargs["dtype"]
disabled_formats = module._disabled_formats
layer_name = prefix.rstrip('.')
weight = state_dict.pop(f"{prefix}weight", None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
module.weight = None
return
manually_loaded_keys = [f"{prefix}weight"]
def pop_scale(name, dtype=None):
key = f"{prefix}{name}"
v = state_dict.pop(key, None)
if v is not None:
v = v.to(device=device)
if dtype is not None:
v = v.view(dtype=dtype)
manually_loaded_keys.append(key)
return v
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
else:
module.quant_format = layer_conf.get("format", None)
module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not module._full_precision_mm:
module._full_precision_mm = module._full_precision_mm_config
if module.quant_format in disabled_formats:
module._full_precision_mm = True
if module.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[module.quant_format]
module.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(module.layout_type)
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
scales = {"scale": pop_scale("weight_scale")}
elif module.quant_format == "mxfp8":
bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
if bs is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
scales = {"scale": bs}
elif module.quant_format == "nvfp4":
ts = pop_scale("weight_scale_2")
bs = pop_scale("weight_scale", torch.float8_e4m3fn)
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs}
else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
module.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
requires_grad=False,
)
if load_extra_params:
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
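
Note on `_load_quantized_module`: the helper pops the keys it handles itself, delegates the rest to the default loader, then removes its own keys from `missing_keys` so they are not reported as missing. A small sketch of that bookkeeping, assuming plain dicts and lists; `load_with_manual_keys` and `fake_default_load` are hypothetical stand-ins, not PyTorch or ComfyUI functions.

```python
def load_with_manual_keys(state_dict, prefix, missing_keys, default_load):
    """Pop hand-loaded keys, run the default loader, then clear false 'missing' reports."""
    manually_loaded = []
    for name in ("weight", "weight_scale"):
        key = f"{prefix}{name}"
        if state_dict.pop(key, None) is not None:
            manually_loaded.append(key)
    default_load(state_dict, prefix, missing_keys)
    for key in manually_loaded:
        if key in missing_keys:
            missing_keys.remove(key)
    return manually_loaded


def fake_default_load(state_dict, prefix, missing_keys):
    # Hypothetical default loader: reports every expected key it cannot find.
    for name in ("weight", "bias"):
        if f"{prefix}{name}" not in state_dict:
            missing_keys.append(f"{prefix}{name}")
```

Without the final loop, the popped `weight` key would show up in `missing_keys` even though it was loaded by hand.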
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
extra_quant_params names attributes written as additional top-level keys."""
if not hasattr(module, 'weight'):
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
return sd
bias = getattr(module, 'bias', None)
if bias is not None:
sd[f"{prefix}bias"] = bias
if module.weight is None:
return sd
if isinstance(module.weight, QuantizedTensor):
sd.update(module.weight.state_dict(f"{prefix}weight"))
quant_conf = {"format": module.quant_format}
if getattr(module, '_full_precision_mm_config', False):
quant_conf["full_precision_matrix_mult"] = True
if extra_quant_conf:
quant_conf.update(extra_quant_conf)
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
for name in extra_quant_params:
value = getattr(module, name, None)
if value is not None:
sd[f"{prefix}{name}"] = value
else:
sd[f"{prefix}weight"] = module.weight
return sd
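
Note on the `comfy_quant` key: the config is stored as the UTF-8 bytes of a JSON document, wrapped in a uint8 tensor on save and decoded with `json.loads` on load. The round trip can be sketched without torch, assuming a plain list of byte values in place of the uint8 tensor; `encode_quant_conf` and `decode_quant_conf` are illustrative names only.

```python
import json


def encode_quant_conf(conf):
    """Serialize a config dict to byte values (0..255), as the uint8 tensor would hold them."""
    return list(json.dumps(conf).encode("utf-8"))


def decode_quant_conf(byte_values):
    """Inverse of encode_quant_conf; json.loads accepts UTF-8 bytes directly."""
    return json.loads(bytes(byte_values))
```

Because the payload is ordinary JSON, extra fields such as `num_experts` can be merged in without changing the storage format.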
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast): class MixedPrecisionOps(manual_cast):
@ -994,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_disabled = disabled _disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp): class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__( _disabled_formats = disabled
self,
in_features: int, def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
super().__init__() super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self._orig_shape = (out_features, in_features)
if bias: if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else: else:
@ -1021,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def reset_parameters(self): def reset_parameters(self):
return None return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): def _load_from_state_dict(self, *args):
key = f"{prefix}{param_name}" _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
# Load format-specific parameters
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
# FP8: single tensor scale
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs): def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None: sd = destination if destination is not None else {}
sd = destination return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
else:
sd = {}
if not hasattr(self, 'weight'):
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
return sd
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
@ -1255,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight, requires_grad=False) self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse: return _quantized_apply(self, fn, recurse)
for module in self.children():
module._apply(fn)
for key, param in self._parameters.items(): class MoEExperts(torch.nn.Module, CastWeightBiasOp):
if param is None: """Container for E quantized expert weights, indexed via expert_weight(i).
continue
p = fn(param) The bank lives on self.weight as a single 3D tensor: either a
if (not torch.is_inference_mode_enabled()) and p.is_inference(): compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
p = p.clone() with leading expert dim.
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items(): State-dict layout matches mixed_precision_ops.Linear with a leading
if buf is not None: expert dim:
self._buffers[key] = fn(buf) {prefix}.weight quant data (storage_t), leading dim = E
return self {prefix}.weight_scale block / per-tensor scale
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
{prefix}.bias [E, out_features] optional, compute_dtype
{prefix}.comfy_quant json -> {"format": "...", "num_experts": E}
Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
"""
_disabled_formats = disabled
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
super().__init__()
self.num_experts = num_experts
self.in_features = in_features
self.out_features = out_features
self._orig_shape = (num_experts, out_features, in_features)
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
if bias:
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
else:
self.register_parameter("bias", None)
# Populated by _load_from_state_dict:
self.weight = None
self.quant_format = None
self.layout_type = None
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
self._full_precision_mm_config = False
self._resident_bank = None
def reset_parameters(self):
return None
def _apply(self, fn, recurse=True):
return _quantized_apply(self, fn, recurse)
def _load_from_state_dict(self, *args):
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)
def expert_weight(self, i: int):
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
if isinstance(self.weight, QuantizedTensor):
return self._expert_qt_from(self.weight, i)
return self.weight[i]
@contextlib.contextmanager
def bank_resident(self, input):
"""Cast the whole bank once; expert_linear inside reuses the cast.
Not re-entrant; do not nest calls on the same instance.
"""
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
self._resident_bank = (weight, bias)
try:
yield self
finally:
self._resident_bank = None
uncast_bias_weight(self, weight, bias, offload_stream)
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
"""Linear against expert i's weight (with optional bias)."""
resident = getattr(self, "_resident_bank", None)
if resident is not None:
weight, bias = resident
return self._expert_linear_impl(input, weight, bias, i)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
try:
return self._expert_linear_impl(input, weight, bias, i)
finally:
uncast_bias_weight(self, weight, bias, offload_stream)
def _expert_linear_impl(self, input, weight, bias, i):
if isinstance(weight, QuantizedTensor):
qw = self._expert_qt_from(weight, i)
else:
qw = weight[i]
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
if isinstance(qw, QuantizedTensor):
use_fast = (
not self._full_precision_mm
and qw.layout_cls.supports_fast_matmul()
and input.dim() == 2
)
if use_fast:
qin = QuantizedTensor.from_float(input, self.layout_type)
return torch.nn.functional.linear(qin, qw, b)
out = input @ qw.dequantize().t()
return out + b if b is not None else out
return torch.nn.functional.linear(input, qw, b)
def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
params = weight._params
kwargs = {
"scale": params.scale[i] if params.scale.dim() else params.scale,
"orig_dtype": params.orig_dtype,
"orig_shape": (self.out_features, self.in_features),
}
if hasattr(params, "block_scale"): # NVFP4
kwargs["block_scale"] = params.block_scale[i]
return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
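
Note on `bank_resident`: the context manager casts the whole expert bank once and lets `expert_linear` reuse that cast until the block exits; outside the block every call casts on its own. A toy sketch of the pattern, assuming a list of numbers in place of the weight bank; `Bank`, `resident` and `apply` are hypothetical names, and `cast_count` exists only to make the reuse observable.

```python
import contextlib


class Bank:
    """Toy stand-in for a module whose weights are expensive to cast."""

    def __init__(self, weights):
        self.weights = weights
        self._resident = None
        self.cast_count = 0

    def _cast(self):
        self.cast_count += 1
        return [float(w) for w in self.weights]

    @contextlib.contextmanager
    def resident(self):
        # Cast once and keep the result until the block exits.
        # Not re-entrant: a nested call would overwrite and then clear the cached cast.
        self._resident = self._cast()
        try:
            yield self
        finally:
            self._resident = None

    def apply(self, x, i):
        # Reuse the resident cast when available, otherwise cast for this call only.
        bank = self._resident if self._resident is not None else self._cast()
        return x * bank[i]
```

The `finally` clause mirrors the diff: the cached reference is cleared even if the body raises.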
class Embedding(manual_cast.Embedding): class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight" weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None: if layer_conf is not None:
@ -1281,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Only fp8 makes sense for embeddings (per-row dequant via index select). # Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None quant_format = layer_conf.get("format") if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: manually_loaded_keys = []
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
self.quant_format = quant_format self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format] qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"] self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type) layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key) weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key] manually_loaded_keys.append(weight_key)
scale_key = f"{prefix}weight_scale" scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None) scale = state_dict.pop(scale_key, None)
@ -1304,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter( self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False) requires_grad=False)
elif layer_conf is not None:
# Unsupported format — restore the marker so it round-trips; fall through to default load.
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys: for k in manually_loaded_keys:
if k in missing_keys: if k in missing_keys:
missing_keys.remove(k) missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs): def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None: sd = destination if destination is not None else {}
sd = destination return _quantized_weight_state_dict(self, sd, prefix)
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None): def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight weight = self.weight


@ -1,8 +1,9 @@
from __future__ import annotations
from typing import Callable from typing import Callable
class CallbacksMP: class CallbacksMP:
ON_CLONE = "on_clone" ON_CLONE = "on_clone"
ON_DEEPCLONE_MULTIGPU = "on_deepclone_multigpu"
ON_MATCH_MULTIGPU_CLONES = "on_match_multigpu_clones"
ON_LOAD = "on_load_after" ON_LOAD = "on_load_after"
ON_DETACH = "on_detach_after" ON_DETACH = "on_detach_after"
ON_CLEANUP = "on_cleanup" ON_CLEANUP = "on_cleanup"


@ -2,42 +2,62 @@ import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer import comfy_aimdo.host_buffer
import comfy_aimdo.torch import comfy_aimdo.torch
import torch
from comfy.cli_args import args from comfy.cli_args import args
def get_pin(module): def get_pin(module, subset="weights"):
return getattr(module, "_pin", None) pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin
def pin_memory(module): _, _, stack_split, pinned_size = module._pin_state[subset]
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
return pin
module._pin_registered = True
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return pin
def pin_memory(module, subset="weights", size=None):
pin_state = module._pin_state
if args.disable_pinned_memory:
return return
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) pin = get_pin(module, subset)
if pin is not None or pin_state["failed"]:
return
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: hostbuf, stack, stack_split, pinned_size = pin_state[subset]
module.pin_failed = True if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)):
pin_state["failed"] = True
return False return False
try: try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) hostbuf.extend(size=size)
except RuntimeError: except RuntimeError:
module.pin_failed = True pin_state["failed"] = True
return False return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin_hostbuf = hostbuf module._pin.untyped_storage()._comfy_hostbuf = hostbuf
stack.append((module, offset))
module._pin_registered = True
module._pin_stack_index = len(stack) - 1
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
return True return True
def unpin_memory(module):
if get_pin(module) is None:
return 0
size = module._pin.numel() * module._pin.element_size()
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin
del module._pin_hostbuf
return size
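
Note on the new `pin_memory`: instead of one host buffer per module, each module now takes a slice of a shared, growing host buffer and is recorded on a stack together with its offset. A rough sketch of that allocation scheme, assuming a `bytearray` in place of the host buffer and a simple byte cap in place of the budget checks; `PinArena` and its fields are hypothetical names.

```python
class PinArena:
    """Toy arena: one growing buffer, one (owner, offset) stack entry per allocation."""

    def __init__(self, budget):
        self.buffer = bytearray()
        self.stack = []
        self.budget = budget  # hypothetical cap, standing in for the pin budget check

    def allocate(self, owner, size):
        if len(self.buffer) + size > self.budget:
            return None  # caller would record the failure and stop retrying
        offset = len(self.buffer)          # new slice starts at the current end
        self.buffer.extend(bytes(size))    # grow the shared buffer
        self.stack.append((owner, offset))
        return (offset, offset + size)     # byte range owned by this allocation
```

The sketch returns a byte range rather than a view, because a live `memoryview` would prevent the `bytearray` from growing; the real code slices a tensor over the host buffer instead.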


@ -1,16 +1,18 @@
from __future__ import annotations from __future__ import annotations
import torch
import uuid import uuid
import math import math
import collections import collections
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.model_patcher
import comfy.utils import comfy.utils
import comfy.hooks import comfy.hooks
import comfy.patcher_extension import comfy.patcher_extension
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from comfy.controlnet import ControlBase from comfy.controlnet import ControlBase
def prepare_mask(noise_mask, shape, device): def prepare_mask(noise_mask, shape, device):
@ -119,6 +121,47 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'): if hasattr(m, 'cleanup'):
m.cleanup() m.cleanup()
def preprocess_multigpu_conds(conds: dict[str, list[dict[str]]], model: ModelPatcher, model_options: dict[str]):
'''If multigpu acceleration is required, creates deepclones of ControlNets per device (GLIGEN is not handled yet).'''
multigpu_models: list[ModelPatcher] = model.get_additional_models_with_key("multigpu")
if len(multigpu_models) == 0:
return
extra_devices = [x.load_device for x in multigpu_models]
# handle controlnets
controlnets: set[ControlBase] = set()
for k in conds:
for kk in conds[k]:
if 'control' in kk:
controlnets.add(kk['control'])
if len(controlnets) > 0:
# first, unload all controlnet clones
for cnet in list(controlnets):
cnet_models = cnet.get_models()
for cm in cnet_models:
comfy.model_management.unload_model_and_clones(cm, unload_additional_models=True)
# next, make sure each controlnet has a deepclone for all relevant devices
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
for device in extra_devices:
if device not in curr_cnet.multigpu_clones:
curr_cnet.deepclone_multigpu(device, autoregister=True)
curr_cnet = curr_cnet.previous_controlnet
# since all device clones are now present, recreate the linked list for cloned cnets per device
for cnet in controlnets:
curr_cnet = cnet
while curr_cnet is not None:
prev_cnet = curr_cnet.previous_controlnet
for device in extra_devices:
device_cnet = curr_cnet.get_instance_for_device(device)
prev_device_cnet = None
if prev_cnet is not None:
prev_device_cnet = prev_cnet.get_instance_for_device(device)
device_cnet.set_previous_controlnet(prev_device_cnet)
curr_cnet = prev_cnet
# potentially handle gligen - since not widely used, ignored for now
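
Note on `preprocess_multigpu_conds`: after unloading, the function works in two passes. First every ControlNet in a chain gets a clone per extra device, then each clone's `previous_controlnet` is pointed at the same-device clone of the original's predecessor. A toy sketch of those two passes, assuming string device names and a minimal `Node` class; `Node`, `instance_for` and `relink_clones` are hypothetical and only mimic the shape of the real ControlNet methods.

```python
class Node:
    """Toy ControlNet-like node: a chain via `previous`, plus one clone per extra device."""

    def __init__(self, name, previous=None):
        self.name = name
        self.previous = previous
        self.clones = {}

    def instance_for(self, device):
        return self.clones[device]


def relink_clones(head, devices):
    # Pass 1: make sure every node in the chain has a clone for every device.
    node = head
    while node is not None:
        for device in devices:
            if device not in node.clones:
                node.clones[device] = Node(f"{node.name}@{device}")
        node = node.previous
    # Pass 2: point each clone at the same-device clone of its original's predecessor.
    node = head
    while node is not None:
        for device in devices:
            prev_clone = node.previous.instance_for(device) if node.previous is not None else None
            node.instance_for(device).previous = prev_clone
        node = node.previous
```

Doing the linking in a separate pass means every predecessor clone already exists when it is looked up.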
def estimate_memory(model, noise_shape, conds): def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list) cond_shapes = collections.defaultdict(list)
cond_shapes_min = {} cond_shapes_min = {}
@ -143,7 +186,8 @@ def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None
return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload) return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None model.match_multigpu_clones()
preprocess_multigpu_conds(conds, model, model_options)
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options) models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update? models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
@ -155,7 +199,7 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
memory_required += inference_memory memory_required += inference_memory
minimum_memory_required += inference_memory minimum_memory_required += inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model real_model: BaseModel = model.model
return real_model, conds, models return real_model, conds, models
@ -201,3 +245,18 @@ def prepare_model_patcher(model: ModelPatcher, conds, model_options: dict):
comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name], comfy.patcher_extension.merge_nested_dicts(to_load_options.setdefault(wc_name, {}), model_options["transformer_options"][wc_name],
copy_dict1=False) copy_dict1=False)
return to_load_options return to_load_options
def prepare_model_patcher_multigpu_clones(model_patcher: ModelPatcher, loaded_models: list[ModelPatcher], model_options: dict):
'''
In case multigpu acceleration is enabled, prep ModelPatchers for each device.
'''
multigpu_patchers: list[ModelPatcher] = [x for x in loaded_models if x.is_multigpu_base_clone]
if len(multigpu_patchers) > 0:
multigpu_dict: dict[torch.device, ModelPatcher] = {}
multigpu_dict[model_patcher.load_device] = model_patcher
for x in multigpu_patchers:
x.hook_patches = comfy.model_patcher.create_hook_patches_clone(model_patcher.hook_patches, copy_tuples=True)
x.hook_mode = model_patcher.hook_mode # match main model's hook_mode
multigpu_dict[x.load_device] = x
model_options["multigpu_clones"] = multigpu_dict
return multigpu_patchers


@ -1,7 +1,9 @@
from __future__ import annotations from __future__ import annotations
import comfy.model_management
from .k_diffusion import sampling as k_diffusion_sampling from .k_diffusion import sampling as k_diffusion_sampling
from .extra_samplers import uni_pc from .extra_samplers import uni_pc
from typing import TYPE_CHECKING, Callable, NamedTuple from typing import TYPE_CHECKING, Callable, NamedTuple, Any
if TYPE_CHECKING: if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel from comfy.model_base import BaseModel
@ -16,6 +18,7 @@ import comfy.model_patcher
import comfy.patcher_extension import comfy.patcher_extension
import comfy.hooks import comfy.hooks
import comfy.context_windows import comfy.context_windows
import comfy.multigpu
import comfy.utils import comfy.utils
import scipy.stats import scipy.stats
import numpy import numpy
@ -141,7 +144,7 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning) return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list): def cond_cat(c_list, device=None):
temp = {} temp = {}
for x in c_list: for x in c_list:
for k in x: for k in x:
@ -153,6 +156,8 @@ def cond_cat(c_list):
for k in temp: for k in temp:
conds = temp[k] conds = temp[k]
out[k] = conds[0].concat(conds[1:]) out[k] = conds[0].concat(conds[1:])
if device is not None and hasattr(out[k], 'to'):
out[k] = out[k].to(device)
return out return out
@ -212,7 +217,12 @@ def _calc_cond_batch_outer(model: BaseModel, conds: list[list[dict]], x_in: torc
) )
return executor.execute(model, conds, x_in, timestep, model_options) return executor.execute(model, conds, x_in, timestep, model_options)
def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep, model_options): def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
# NOTE: keep in sync with _calc_cond_batch_multigpu below. Shared logic
# (hooked_to_run accumulation, memory-fit batching, per-chunk output
# aggregation) is duplicated there with per-device scheduling layered on top.
if 'multigpu_clones' in model_options:
return _calc_cond_batch_multigpu(model, conds, x_in, timestep, model_options)
out_conds = [] out_conds = []
out_counts = [] out_counts = []
# separate conds by matching hooks # separate conds by matching hooks
@ -244,7 +254,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
if has_default_conds: if has_default_conds:
finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options) finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)
model.current_patcher.prepare_state(timestep) model.current_patcher.prepare_state(timestep, model_options)
# run every hooked_to_run separately # run every hooked_to_run separately
for hooks, to_run in hooked_to_run.items(): for hooks, to_run in hooked_to_run.items():
@ -265,7 +275,6 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
cond_shapes = collections.defaultdict(list) cond_shapes = collections.defaultdict(list)
for tt in batch_amount: for tt in batch_amount:
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
for k, v in to_run[tt][0].conditioning.items(): for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size()) cond_shapes[k].append(v.size())
@ -345,6 +354,239 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
return out_conds return out_conds
def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
    # NOTE: keep in sync with _calc_cond_batch above. Same conds-by-hooks
    # accumulation, memory-fit batching, and output aggregation, but adds a
    # per-device scheduler, per-device patcher/control lookup, tensor .to(device)
    # placement, and MultiGPUThreadPool dispatch around the inner loop.
    out_conds = []
    out_counts = []
    # separate conds by matching hooks
    hooked_to_run: dict[comfy.hooks.HookGroup,list[tuple[tuple,int]]] = {}
    default_conds = []
    has_default_conds = False

    output_device = x_in.device

    for i in range(len(conds)):
        out_conds.append(torch.zeros_like(x_in))
        out_counts.append(torch.ones_like(x_in) * 1e-37)

        cond = conds[i]
        default_c = []
        if cond is not None:
            for x in cond:
                if 'default' in x:
                    default_c.append(x)
                    has_default_conds = True
                    continue
                p = get_area_and_mult(x, x_in, timestep)
                if p is None:
                    continue
                if p.hooks is not None:
                    model.current_patcher.prepare_hook_patches_current_keyframe(timestep, p.hooks, model_options)
                hooked_to_run.setdefault(p.hooks, list())
                hooked_to_run[p.hooks] += [(p, i)]
        default_conds.append(default_c)

    if has_default_conds:
        finalize_default_conds(model, hooked_to_run, default_conds, x_in, timestep, model_options)

    model.current_patcher.prepare_state(timestep, model_options)

    devices = list(model_options['multigpu_clones'].keys())
    device_batched_hooked_to_run: dict[torch.device, list[tuple[comfy.hooks.HookGroup, tuple]]] = {}
    # Track conds currently scheduled per device; single source of truth for capacity checks.
    device_load: dict[torch.device, int] = {d: 0 for d in devices}

    total_conds = sum(len(to_run) for to_run in hooked_to_run.values())
    conds_per_device = max(1, math.ceil(total_conds / len(devices)))

    def next_available_device(start: int) -> tuple[int, torch.device]:
        """Return (index, device) for the next device with remaining capacity, starting at `start`.
        Scans at most len(devices) positions, so this always terminates. Raises if no device
        has remaining capacity, which would indicate a bug in conds_per_device accounting.
        """
        for offset in range(len(devices)):
            i = (start + offset) % len(devices)
            if device_load[devices[i]] < conds_per_device:
                return i, devices[i]
        raise RuntimeError(
            f"MultiGPU scheduler: all {len(devices)} devices at capacity "
            f"({conds_per_device}) but conds remain to schedule"
        )

    # run every hooked_to_run separately
    index_device = 0
    for hooks, to_run in hooked_to_run.items():
        while len(to_run) > 0:
            index_device, current_device = next_available_device(index_device)
            remaining_capacity = conds_per_device - device_load[current_device]

            first = to_run[0]
            first_shape = first[0][0].shape
            # collect candidate indices that can be concatenated with `first`, up to remaining capacity
            to_batch_temp = []
            for x in range(len(to_run)):
                if can_concat_cond(to_run[x][0], first[0]) and len(to_batch_temp) < remaining_capacity:
                    to_batch_temp += [x]

            to_batch_temp.reverse()
            to_batch = to_batch_temp[:1]

            free_memory = comfy.model_management.get_free_memory(current_device)
            for i in range(1, len(to_batch_temp) + 1):
                batch_amount = to_batch_temp[:len(to_batch_temp)//i]
                input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
                cond_shapes = collections.defaultdict(list)
                for tt in batch_amount:
                    for k, v in to_run[tt][0].conditioning.items():
                        cond_shapes[k].append(v.size())

                if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
                    to_batch = batch_amount
                    break

            conds_to_batch = [to_run.pop(x) for x in to_batch]
            device_load[current_device] += len(conds_to_batch)
            device_batched_hooked_to_run.setdefault(current_device, []).append((hooks, conds_to_batch))
            if device_load[current_device] >= conds_per_device:
                index_device += 1

    class thread_result(NamedTuple):
        output: Any
        mult: Any
        area: Any
        batch_chunks: int
        cond_or_uncond: Any
        error: Exception = None

    def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
        try:
            # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
            # we extend multigpu QA beyond CUDA. Unconditional call crashes on
            # XPU/NPU/MPS/CPU/DirectML backends.
            torch.cuda.set_device(device)
            model_current: BaseModel = model_options["multigpu_clones"][device].model
            # run every hooked_to_run separately
            with torch.no_grad():
                for hooks, to_batch in batch_tuple:
                    input_x = []
                    mult = []
                    c = []
                    cond_or_uncond = []
                    uuids = []
                    area = []
                    control: ControlBase = None
                    patches = None
                    for x in to_batch:
                        o = x
                        p = o[0]
                        input_x.append(p.input_x)
                        mult.append(p.mult)
                        c.append(p.conditioning)
                        area.append(p.area)
                        cond_or_uncond.append(o[1])
                        uuids.append(p.uuid)
                        control = p.control
                        patches = p.patches

                    batch_chunks = len(cond_or_uncond)
                    input_x = torch.cat(input_x).to(device)
                    c = cond_cat(c, device=device)
                    timestep_ = torch.cat([timestep.to(device)] * batch_chunks)

                    transformer_options = model_current.current_patcher.apply_hooks(hooks=hooks)
                    if 'transformer_options' in model_options:
                        transformer_options = comfy.patcher_extension.merge_nested_dicts(transformer_options,
                                                                                         model_options['transformer_options'],
                                                                                         copy_dict1=False)

                    if patches is not None:
                        transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
                            transformer_options.get("patches", {}),
                            patches
                        )

                    transformer_options["cond_or_uncond"] = cond_or_uncond[:]
                    transformer_options["uuids"] = uuids[:]
                    transformer_options["sigmas"] = timestep.to(device)
                    transformer_options["sample_sigmas"] = transformer_options["sample_sigmas"].to(device)
                    transformer_options["multigpu_thread_device"] = device

                    cast_transformer_options(transformer_options, device=device)
                    c['transformer_options'] = transformer_options

                    if control is not None:
                        device_control = control.get_instance_for_device(device)
                        c['control'] = device_control.get_control(input_x, timestep_, c, len(cond_or_uncond), transformer_options)

                    if 'model_function_wrapper' in model_options:
                        output = model_options['model_function_wrapper'](model_current.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).to(output_device).chunk(batch_chunks)
                    else:
                        output = model_current.apply_model(input_x, timestep_, **c).to(output_device).chunk(batch_chunks)
                    # TODO: non-NVIDIA support -- the `.to(output_device)` copies
                    # above are async on CUDA, so the main thread's aggregation
                    # could race with in-flight transfers. CUDA-only QA has not
                    # surfaced this in practice, but before extending multigpu
                    # beyond NVIDIA add a `torch.cuda.synchronize(output_device)`
                    # here (guarded by `output_device.type == "cuda"`).
                    results.append(thread_result(output, mult, area, batch_chunks, cond_or_uncond))
        except Exception as e:
            results.append(thread_result(None, None, None, None, None, error=e))
            raise

    def _handle_batch_pooled(device, batch_tuple):
        worker_results = []
        _handle_batch(device, batch_tuple, worker_results)
        return worker_results

    results: list[thread_result] = []
    thread_pool: comfy.multigpu.MultiGPUThreadPool = model_options.get("multigpu_thread_pool")

    # Submit all GPU work to pool threads
    pool_devices = []
    for device, batch_tuple in device_batched_hooked_to_run.items():
        if thread_pool is not None:
            thread_pool.submit(device, _handle_batch_pooled, device, batch_tuple)
            pool_devices.append(device)
        else:
            # Fallback: no pool, run everything on main thread
            _handle_batch(device, batch_tuple, results)

    # Collect results from pool workers
    for device in pool_devices:
        worker_results, error = thread_pool.get_result(device)
        if error is not None:
            raise error
        results.extend(worker_results)

    for output, mult, area, batch_chunks, cond_or_uncond, error in results:
        if error is not None:
            raise error
        for o in range(batch_chunks):
            cond_index = cond_or_uncond[o]
            a = area[o]
            if a is None:
                out_conds[cond_index] += output[o] * mult[o]
                out_counts[cond_index] += mult[o]
            else:
                out_c = out_conds[cond_index]
                out_cts = out_counts[cond_index]
                dims = len(a) // 2
                for i in range(dims):
                    out_c = out_c.narrow(i + 2, a[i + dims], a[i])
                    out_cts = out_cts.narrow(i + 2, a[i + dims], a[i])
                out_c += output[o] * mult[o]
                out_cts += mult[o]

    for i in range(len(out_conds)):
        out_conds[i] /= out_counts[i]

    return out_conds
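
Reviewer sketch (not part of the commit): the per-device scheduling in `_calc_cond_batch_multigpu` can be read as a capacity-bounded round-robin. The standalone example below illustrates only that idea. The function name `schedule_conds`, the string device labels, and the use of plain counts in place of cond objects are assumptions made for illustration. It deliberately leaves out the `can_concat_cond` compatibility check and the memory-fit batching that the real function performs.

```python
import math


def schedule_conds(cond_counts, devices):
    """Spread queued conds over devices, each capped at ceil(total / n_devices).

    cond_counts: dict mapping a hook-group label to how many conds it has queued.
    Returns a dict mapping each device to a list of (label, n_conds) batches.
    """
    total = sum(cond_counts.values())
    capacity = max(1, math.ceil(total / len(devices)))
    load = {d: 0 for d in devices}
    plan = {d: [] for d in devices}

    def next_available(start):
        # Scan at most len(devices) positions, wrapping around with modulo.
        for offset in range(len(devices)):
            i = (start + offset) % len(devices)
            if load[devices[i]] < capacity:
                return i, devices[i]
        raise RuntimeError("all devices at capacity but conds remain")

    index = 0
    for label, remaining in cond_counts.items():
        while remaining > 0:
            index, device = next_available(index)
            take = min(remaining, capacity - load[device])
            load[device] += take
            plan[device].append((label, take))
            remaining -= take
            if load[device] >= capacity:
                index += 1  # move on once this device is full
    return plan


plan = schedule_conds({"a": 3, "b": 1}, ["cuda:0", "cuda:1"])
print(plan)
```

Because the capacity is the ceiling of the average, the summed capacity is never smaller than the number of conds, so the `RuntimeError` branch should only fire if the load bookkeeping is wrong.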
def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove def calc_cond_uncond_batch(model, cond, uncond, x_in, timestep, model_options): #TODO: remove
logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.") logging.warning("WARNING: The comfy.samplers.calc_cond_uncond_batch function is deprecated please use the calc_cond_batch one instead.")
return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options)) return tuple(calc_cond_batch(model, [cond, uncond], x_in, timestep, model_options))
@ -643,12 +885,21 @@ def calculate_start_end_timesteps(model, conds):
def pre_run_control(model, conds): def pre_run_control(model, conds):
s = model.model_sampling s = model.model_sampling
# Per-device model lookup so multigpu control clones get the matching
# diffusion_model (e.g. QwenFunControlNet stashes it into extra_args).
device_models: dict = {}
patcher = getattr(model, "current_patcher", None)
if patcher is not None:
for p in patcher.get_additional_models_with_key("multigpu"):
device_models[p.load_device] = p.model
for t in range(len(conds)): for t in range(len(conds)):
x = conds[t] x = conds[t]
percent_to_timestep_function = lambda a: s.percent_to_sigma(a) percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x: if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function) x['control'].pre_run(model, percent_to_timestep_function)
for device, device_cnet in x['control'].multigpu_clones.items():
device_cnet.pre_run(device_models.get(device, model), percent_to_timestep_function)
def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func): def apply_empty_x_to_equal_area(conds, uncond, name, uncond_fill_func):
cond_cnets = [] cond_cnets = []
@ -891,7 +1142,9 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
to_load_options = model_options.get("to_load_options", None) to_load_options = model_options.get("to_load_options", None)
if to_load_options is None: if to_load_options is None:
return return
cast_transformer_options(to_load_options, device, dtype)
def cast_transformer_options(transformer_options: dict[str], device=None, dtype=None):
casts = [] casts = []
if device is not None: if device is not None:
casts.append(device) casts.append(device)
@ -900,18 +1153,17 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# if nothing to apply, do nothing # if nothing to apply, do nothing
if len(casts) == 0: if len(casts) == 0:
return return
# try to call .to on patches # try to call .to on patches
if "patches" in to_load_options: if "patches" in transformer_options:
patches = to_load_options["patches"] patches = transformer_options["patches"]
for name in patches: for name in patches:
patch_list = patches[name] patch_list = patches[name]
for i in range(len(patch_list)): for i in range(len(patch_list)):
if hasattr(patch_list[i], "to"): if hasattr(patch_list[i], "to"):
for cast in casts: for cast in casts:
patch_list[i] = patch_list[i].to(cast) patch_list[i] = patch_list[i].to(cast)
if "patches_replace" in to_load_options: if "patches_replace" in transformer_options:
patches = to_load_options["patches_replace"] patches = transformer_options["patches_replace"]
for name in patches: for name in patches:
patch_list = patches[name] patch_list = patches[name]
for k in patch_list: for k in patch_list:
@ -921,8 +1173,8 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
# try to call .to on any wrappers/callbacks # try to call .to on any wrappers/callbacks
wrappers_and_callbacks = ["wrappers", "callbacks"] wrappers_and_callbacks = ["wrappers", "callbacks"]
for wc_name in wrappers_and_callbacks: for wc_name in wrappers_and_callbacks:
if wc_name in to_load_options: if wc_name in transformer_options:
wc: dict[str, list] = to_load_options[wc_name] wc: dict[str, list] = transformer_options[wc_name]
for wc_dict in wc.values(): for wc_dict in wc.values():
for wc_list in wc_dict.values(): for wc_list in wc_dict.values():
for i in range(len(wc_list)): for i in range(len(wc_list)):
@ -930,7 +1182,6 @@ def cast_to_load_options(model_options: dict[str], device=None, dtype=None):
for cast in casts: for cast in casts:
wc_list[i] = wc_list[i].to(cast) wc_list[i] = wc_list[i].to(cast)
class CFGGuider: class CFGGuider:
def __init__(self, model_patcher: ModelPatcher): def __init__(self, model_patcher: ModelPatcher):
self.model_patcher = model_patcher self.model_patcher = model_patcher
@ -985,16 +1236,32 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device device = self.model_patcher.load_device
multigpu_patchers = comfy.sampler_helpers.prepare_model_patcher_multigpu_clones(self.model_patcher, self.loaded_models, self.model_options)
# Create persistent thread pool for all GPU devices (main + extras)
if multigpu_patchers:
extra_devices = [p.load_device for p in multigpu_patchers]
all_devices = [device] + extra_devices
self.model_options["multigpu_thread_pool"] = comfy.multigpu.MultiGPUThreadPool(all_devices)
with comfy.model_management.cuda_device_context(device):
try:
noise = noise.to(device=device, dtype=torch.float32) noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device=device, dtype=torch.float32) latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device) sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
try:
self.model_patcher.pre_run() self.model_patcher.pre_run()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally: finally:
thread_pool = self.model_options.pop("multigpu_thread_pool", None)
if thread_pool is not None:
thread_pool.shutdown()
self.model_patcher.cleanup() self.model_patcher.cleanup()
for multigpu_patcher in multigpu_patchers:
multigpu_patcher.cleanup()
comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models) comfy.sampler_helpers.cleanup_models(self.conds, self.loaded_models)
del self.inner_model del self.inner_model
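
Reviewer sketch (not part of the commit): `comfy.multigpu.MultiGPUThreadPool` is not shown in this diff, so its internals are unknown here. The call sites above only reveal the shape of its interface: construction from a list of devices, `submit(device, fn, *args)`, `get_result(device)` returning a `(result, error)` pair, and `shutdown()`. The class below is a hypothetical stand-in with that same shape, built on the standard library, to show one way a pool with one persistent worker thread per device could behave. The name `PerDeviceWorkerPool` and the string device keys are assumptions.

```python
import queue
import threading


class PerDeviceWorkerPool:
    """Hypothetical stand-in: one long-lived worker thread per device key."""

    def __init__(self, devices):
        self._tasks = {d: queue.Queue() for d in devices}
        self._results = {d: queue.Queue() for d in devices}
        self._threads = []
        for d in devices:
            t = threading.Thread(target=self._worker, args=(d,), daemon=True)
            t.start()
            self._threads.append(t)

    def _worker(self, device):
        while True:
            task = self._tasks[device].get()
            if task is None:  # shutdown sentinel
                return
            fn, args = task
            try:
                self._results[device].put((fn(*args), None))
            except Exception as e:
                # Hand the exception back to the caller instead of losing it.
                self._results[device].put((None, e))

    def submit(self, device, fn, *args):
        self._tasks[device].put((fn, args))

    def get_result(self, device):
        # Blocks until the worker for `device` has finished one task.
        return self._results[device].get()

    def shutdown(self):
        for q in self._tasks.values():
            q.put(None)
        for t in self._threads:
            t.join()


# Usage mirrors the submit/collect pattern above: submit everything first,
# then collect per device and re-raise any worker error on the main thread.
pool = PerDeviceWorkerPool(["cuda:0", "cuda:1"])
try:
    for d in ("cuda:0", "cuda:1"):
        pool.submit(d, lambda name: [name.upper()], d)
    collected = []
    for d in ("cuda:0", "cuda:1"):
        results, error = pool.get_result(d)
        if error is not None:
            raise error
        collected.extend(results)
finally:
    pool.shutdown()
print(collected)
```

Shutting the pool down in a `finally` block, as the diff does, keeps worker threads from outliving a sampling run that fails.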


@ -1,4 +1,3 @@
from __future__ import annotations
import json import json
import torch import torch
from enum import Enum from enum import Enum
@ -50,6 +49,7 @@ import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2 import comfy.text_encoders.lumina2
import comfy.text_encoders.pixeldit
import comfy.text_encoders.wan import comfy.text_encoders.wan
import comfy.text_encoders.hidream import comfy.text_encoders.hidream
import comfy.text_encoders.ace import comfy.text_encoders.ace
@ -69,6 +69,7 @@ import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4 import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3 import comfy.text_encoders.sa3
import comfy.text_encoders.gpt_oss
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -335,12 +336,14 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens) self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
all_hooks.reset() all_hooks.reset()
self.patcher.patch_hooks(None) self.patcher.patch_hooks(None)
if show_pbar: if show_pbar:
pbar = ProgressBar(len(scheduled_keyframes)) pbar = ProgressBar(len(scheduled_keyframes))
with model_management.cuda_device_context(device):
for scheduled_opts in scheduled_keyframes: for scheduled_opts in scheduled_keyframes:
t_range = scheduled_opts[0] t_range = scheduled_opts[0]
# don't bother encoding any conds outside of start_percent and end_percent bounds # don't bother encoding any conds outside of start_percent and end_percent bounds
@ -383,8 +386,12 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model(tokens) self.load_model(tokens)
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"execution_device": device})
with model_management.cuda_device_context(device):
o = self.cond_stage_model.encode_token_weights(tokens) o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2] cond, pooled = o[:2]
if return_dict: if return_dict:
out = {"cond": cond, "pooled_output": pooled} out = {"cond": cond, "pooled_output": pooled}
@ -446,8 +453,11 @@ class CLIP:
self.cond_stage_model.reset_clip_options() self.cond_stage_model.reset_clip_options()
self.load_model(tokens) self.load_model(tokens)
device = self.patcher.load_device
self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) self.cond_stage_model.set_clip_options({"execution_device": device})
with model_management.cuda_device_context(device):
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed, presence_penalty=presence_penalty)
def decode(self, token_ids, skip_special_tokens=True): def decode(self, token_ids, skip_special_tokens=True):
@ -1026,6 +1036,8 @@ class VAE:
do_tile = False do_tile = False
if self.latent_dim == 2 and samples_in.ndim == 5: if self.latent_dim == 2 and samples_in.ndim == 5:
samples_in = samples_in[:, :, 0] samples_in = samples_in[:, :, 0]
with model_management.cuda_device_context(self.device):
try: try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1087,6 +1099,7 @@ class VAE:
if overlap is not None: if overlap is not None:
args["overlap"] = overlap args["overlap"] = overlap
with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None: if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y") args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args) output = self.decode_tiled_1d(samples, **args)
@ -1113,6 +1126,8 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else: else:
pixel_samples = pixel_samples.unsqueeze(2) pixel_samples = pixel_samples.unsqueeze(2)
with model_management.cuda_device_context(self.device):
try: try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1176,6 +1191,7 @@ class VAE:
if overlap is not None: if overlap is not None:
args["overlap"] = overlap args["overlap"] = overlap
with model_management.cuda_device_context(self.device):
if dims == 1: if dims == 1:
args.pop("tile_y") args.pop("tile_y")
samples = self.encode_tiled_1d(pixel_samples, **args) samples = self.encode_tiled_1d(pixel_samples, **args)
@ -1269,6 +1285,8 @@ class CLIPType(Enum):
FLUX2 = 25 FLUX2 = 25
LONGCAT_IMAGE = 26 LONGCAT_IMAGE = 26
COGVIDEOX = 27 COGVIDEOX = 27
LENS = 28
PIXELDIT = 29
@ -1321,6 +1339,7 @@ class TEModel(Enum):
GEMMA_4_E2B = 30 GEMMA_4_E2B = 30
GEMMA_4_31B = 31 GEMMA_4_31B = 31
T5_GEMMA = 32 T5_GEMMA = 32
GPT_OSS_20B = 33
def detect_te_model(sd): def detect_te_model(sd):
@ -1362,6 +1381,9 @@ def detect_te_model(sd):
else: else:
return TEModel.GEMMA_3_4B return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B return TEModel.GEMMA_2_2B
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
return TEModel.GPT_OSS_20B
if 'model.layers.0.self_attn.k_proj.bias' in sd: if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias'] weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256: if weight.shape[0] == 256:
@ -1508,6 +1530,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = variant.tokenizer clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B: elif te_model == TEModel.GEMMA_2_2B:
if clip_type == CLIPType.PIXELDIT:
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
else:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
@ -1544,6 +1570,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.GPT_OSS_20B:
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.QWEN3_4B: elif te_model == TEModel.QWEN3_4B:
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
@ -1710,12 +1740,52 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
if out is None: if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
if output_model and out[0] is not None: if out[0] is not None:
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) out[0].cached_patcher_init = (load_checkpoint_guess_config, (ckpt_path, False, False, False, embedding_directory, output_model, model_options, te_model_options), 0)
if output_clip and out[1] is not None: # Register reload factories for the CLIP and VAE produced by the same checkpoint so
out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) # ModelPatcher.deepclone_multigpu can spawn per-device copies (Select{CLIP,VAE}Device,
# MultiGPU work-units, etc.) without falling back to copy.deepcopy of an
# already-loaded module.
if out[1] is not None and getattr(out[1], "patcher", None) is not None:
out[1].patcher.cached_patcher_init = (load_checkpoint_clip_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
if out[2] is not None and getattr(out[2], "patcher", None) is not None:
out[2].patcher.cached_patcher_init = (load_checkpoint_vae_patcher, (ckpt_path, embedding_directory, model_options, te_model_options))
return out return out
def load_checkpoint_clip_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Reload only the CLIP patcher from a checkpoint. Used as the cached_patcher_init
    factory for the CLIP returned by load_checkpoint_guess_config."""
    _, clip, _, _ = load_checkpoint_guess_config(
        ckpt_path,
        output_vae=False,
        output_clip=True,
        output_clipvision=False,
        embedding_directory=embedding_directory,
        output_model=False,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
    )
    return clip.patcher

def load_checkpoint_vae_patcher(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
    """Reload only the VAE patcher from a checkpoint. Used as the cached_patcher_init
    factory for the VAE returned by load_checkpoint_guess_config."""
    _, _, vae, _ = load_checkpoint_guess_config(
        ckpt_path,
        output_vae=True,
        output_clip=False,
        output_clipvision=False,
        embedding_directory=embedding_directory,
        output_model=False,
        model_options=model_options,
        te_model_options=te_model_options,
        disable_dynamic=disable_dynamic,
    )
    return vae.patcher
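
Reviewer sketch (not part of the commit): the `cached_patcher_init` values set above are tuples of a factory callable and its positional arguments, with an optional trailing index for the model case, where the factory returns a tuple and only one element is wanted. How `ModelPatcher` actually consumes these tuples is not visible in this diff. The helper `rebuild_from_cached_init` and the placeholder loaders below are hypothetical and show only one plausible way to replay such a tuple.

```python
def rebuild_from_cached_init(cached_init):
    """Hypothetical helper: call the stored factory, then optionally index its result."""
    factory, args, *rest = cached_init
    result = factory(*args)
    if rest:
        # A trailing integer selects one element of a tuple-returning factory.
        return result[rest[0]]
    return result


def load_everything(path):
    # Placeholder for a loader that returns (model, clip, vae).
    return (f"model:{path}", f"clip:{path}", f"vae:{path}")


def load_clip_only(path):
    # Placeholder for a loader that returns a single object.
    return load_everything(path)[1]


model_init = (load_everything, ("ckpt.safetensors",), 0)
clip_init = (load_clip_only, ("ckpt.safetensors",))

print(rebuild_from_cached_init(model_init))
print(rebuild_from_cached_init(clip_init))
```

Storing a factory plus arguments rather than the loaded object lets a per-device copy be rebuilt from the checkpoint instead of deep-copying a module that is already resident on another device.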
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
embedding_directory=embedding_directory, embedding_directory=embedding_directory,
@ -1742,7 +1812,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device() load_device = model_options.get("load_device", model_management.get_torch_device())
custom_operations = model_options.get("custom_operations", None) custom_operations = model_options.get("custom_operations", None)
if custom_operations is None: if custom_operations is None:
@ -1782,13 +1852,15 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
- model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
+ offload_device = model_options.get("offload_device", model_management.unet_offload_device())
+ model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
if output_vae: if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd) vae_sd = model_config.process_vae_state_dict(vae_sd)
- vae = VAE(sd=vae_sd, metadata=metadata)
+ vae_device = model_options.get("load_device", None)
+ vae = VAE(sd=vae_sd, metadata=metadata, device=vae_device)
if output_clip: if output_clip:
if te_model_options.get("custom_operations", None) is None: if te_model_options.get("custom_operations", None) is None:
@ -1872,7 +1944,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd) weight_dtype = comfy.utils.weight_dtype(sd)
- load_device = model_management.get_torch_device()
+ load_device = model_options.get("load_device", model_management.get_torch_device())
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata) model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None: if model_config is not None:
@ -1897,7 +1969,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
else: else:
logging.warning("{} {}".format(diffusers_keys[k], k)) logging.warning("{} {}".format(diffusers_keys[k], k))
- offload_device = model_management.unet_offload_device()
+ offload_device = model_options.get("offload_device", model_management.unet_offload_device())
unet_weight_dtype = list(model_config.supported_inference_dtypes) unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.quant_config is not None: if model_config.quant_config is not None:
weight_dtype = None weight_dtype = None
@ -1939,6 +2011,26 @@ def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
return model return model
def load_vae_patcher(vae_path, metadata=None, device=None, disable_dynamic=False):
"""Reload a disk-backed VAE from ``vae_path`` and return its patcher.
Used as the ``cached_patcher_init`` factory on ``VAE.patcher`` so
:meth:`comfy.model_patcher.ModelPatcher.deepclone_multigpu` can produce a
fresh, untainted VAE patcher (no inherited per-device load state, no
in-place quantization fallout) for multigpu work-units and the
SelectVAEDevice node. The optional ``device`` matches the source loader's
VAE initialization path; the deepclone's ``load_device`` still controls
where the cloned patcher is targeted.
"""
if metadata is None:
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
else:
sd = comfy.utils.load_torch_file(vae_path)
vae = VAE(sd=sd, metadata=metadata, device=device)
vae.throw_exception_if_invalid()
return vae.patcher
def load_unet(unet_path, dtype=None): def load_unet(unet_path, dtype=None):
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model") logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
return load_diffusion_model(unet_path, model_options={"dtype": dtype}) return load_diffusion_model(unet_path, model_options={"dtype": dtype})
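
Review note: the hunks in this file let a caller override the load and offload devices through `model_options`, falling back to the `model_management` defaults when a key is absent. A minimal sketch of that lookup pattern follows. Plain strings stand in for `torch.device` objects, and `default_load_device`, `default_offload_device`, `resolve_devices`, and `resolve_vae_device` are hypothetical names used only for illustration, not functions from the codebase.

```python
# Sketch only: plain strings stand in for torch.device objects, and the two
# default_* helpers are hypothetical stand-ins for comfy.model_management calls.

def default_load_device():
    return "cuda:0"


def default_offload_device():
    return "cpu"


def resolve_devices(model_options=None):
    """Return (load_device, offload_device), preferring explicit overrides."""
    opts = model_options or {}
    load_device = opts.get("load_device", default_load_device())
    offload_device = opts.get("offload_device", default_offload_device())
    return load_device, offload_device


def resolve_vae_device(model_options=None):
    """The VAE path falls back to None, leaving the VAE to pick its own device."""
    return (model_options or {}).get("load_device", None)
```

Note that `dict.get` evaluates its default argument eagerly, so the default helper still runs even when an override is present.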


@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1 import comfy.text_encoders.hidream_o1
import comfy.text_encoders.pixeldit
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@ -829,6 +830,50 @@ class Flux2(Flux):
return None return None
class Lens(supported_models_base.BASE):
"""Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
unet_config = {
"image_model": "lens",
}
sampling_settings = {
"shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300)
}
unet_extra_config = {}
latent_format = latent_formats.Flux2
memory_usage_factor = 4.0
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def __init__(self, unet_config):
super().__init__(unet_config)
def get_model(self, state_dict, prefix="", device=None):
return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
for hint in ("gpt_oss.transformer.", ""):
full_prefix = "{}{}".format(pref, hint)
if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
return supported_models_base.ClipTarget(
comfy.text_encoders.gpt_oss.LensTokenizer,
comfy.text_encoders.gpt_oss.lens_te(**detect),
)
return supported_models_base.ClipTarget(
comfy.text_encoders.gpt_oss.LensTokenizer,
comfy.text_encoders.gpt_oss.lens_te(),
)
class GenmoMochi(supported_models_base.BASE): class GenmoMochi(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "mochi_preview", "image_model": "mochi_preview",
@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage):
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device) return model_base.ZImagePixelSpace(self, device=device)
class PixelDiTT2I(supported_models_base.BASE):
unet_config = {
"image_model": "pixeldit_t2i",
}
unet_extra_config = {}
sampling_settings = {
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
}
latent_format = latent_formats.PixelDiTPixel
memory_usage_factor = 0.04
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
return model_base.PixelDiTT2I(self, device=device)
def process_unet_state_dict(self, state_dict):
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
out = {}
marker = ".adaLN_modulation.0."
for k, v in state_dict.items():
if k.startswith("_repa_projector") or k.startswith("net_ema."):
continue
if k.startswith("core."):
k = k[len("core."):]
elif k.startswith("net."):
k = k[len("net."):]
if "pixel_blocks." in k and marker in k:
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
p2 = v.shape[0] // (6 * pixel_dim)
trail = v.shape[1:] # () for bias, (in_dim,) for weight
vv = v.view(p2, 6, pixel_dim, *trail)
base, suffix = k.split(marker)
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
else:
out[k] = v
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
)
class PiD(PixelDiTT2I):
unet_config = {
"image_model": "pid",
}
sampling_settings = {
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
}
memory_usage_factor = 0.04
def get_model(self, state_dict, prefix="", device=None):
return model_base.PiD(self, device=device)
class WAN21_T2V(supported_models_base.BASE): class WAN21_T2V(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -2069,6 +2180,8 @@ models = [
CosmosI2VPredict2, CosmosI2VPredict2,
ZImagePixelSpace, ZImagePixelSpace,
ZImage, ZImage,
PiD,
PixelDiTT2I,
Lumina2, Lumina2,
WAN22_T2V, WAN22_T2V,
WAN21_CausalAR_T2V, WAN21_CausalAR_T2V,
@ -2096,6 +2209,7 @@ models = [
Omnigen2, Omnigen2,
QwenImage, QwenImage,
Flux2, Flux2,
Lens,
Kandinsky5Image, Kandinsky5Image,
Kandinsky5, Kandinsky5,
Anima, Anima,
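
Review note: `PixelDiTT2I.process_unet_state_dict` splits each fused `adaLN_modulation` tensor, laid out as `(p2, 6, pixel_dim, ...)`, into an msa half (chunks 0-2) and an mlp half (chunks 3-5). The sketch below reproduces that reordering for the 1-D bias case using plain Python lists, so the index arithmetic can be checked without torch. The function name `split_adaln_bias` is hypothetical; the weight case in the diff works the same way, with a trailing `in_dim` axis carried along.

```python
def split_adaln_bias(flat, pixel_dim):
    """Split a flat bias laid out as (p2, 6, pixel_dim) into msa and mlp parts.

    Mirrors vv[:, 0:3].reshape(-1) and vv[:, 3:6].reshape(-1) from the diff,
    where vv = v.view(p2, 6, pixel_dim).
    """
    total = len(flat)
    if total % (6 * pixel_dim) != 0:
        raise ValueError("length must be a multiple of 6 * pixel_dim")
    p2 = total // (6 * pixel_dim)
    msa, mlp = [], []
    for p in range(p2):
        base = p * 6 * pixel_dim
        for chunk in range(6):
            start = base + chunk * pixel_dim
            # Chunks 0-2 go to the msa projection, chunks 3-5 to the mlp one.
            target = msa if chunk < 3 else mlp
            target.extend(flat[start:start + pixel_dim])
    return msa, mlp


# Usage: p2 = 2 and pixel_dim = 2 give 24 entries in total.
msa, mlp = split_adaln_bias(list(range(24)), pixel_dim=2)
```

Each output has length `3 * p2 * pixel_dim`, matching the `reshape(3 * p2 * pixel_dim, *trail)` calls in the diff.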


@ -0,0 +1,600 @@
"""GPT-OSS text encoder for Lens."""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope
@dataclass
class GptOss20BConfig:
vocab_size: int = 201088
hidden_size: int = 2880
intermediate_size: int = 2880
num_hidden_layers: int = 24
num_attention_heads: int = 64
num_key_value_heads: int = 8
head_dim: int = 64
num_local_experts: int = 32
num_experts_per_tok: int = 4
sliding_window: int = 128
original_max_position_embeddings: int = 4096
rope_theta: float = 150000.0
rope_factor: float = 32.0
rope_beta_fast: float = 32.0
rope_beta_slow: float = 1.0
rope_truncate: bool = False
rms_norm_eps: float = 1e-5
attention_bias: bool = True
layer_types: Optional[List[str]] = None
moe_alpha: float = 1.702
moe_limit: float = 7.0
def __post_init__(self):
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if (i + 1) % 2 else "full_attention"
for i in range(self.num_hidden_layers)
]
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
"""YARN inv_freq + attention scaling (matches transformers)."""
dim = head_dim
def find_correction_dim(num_rotations: float) -> float:
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def find_correction_range() -> tuple[float, float]:
low = find_correction_dim(beta_fast)
high = find_correction_dim(beta_slow)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
if min_ == max_:
max_ += 0.001
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
return torch.clamp(linear, 0, 1)
def get_mscale(scale: float) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
attention_scaling = get_mscale(factor)
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range()
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
return inv_freq, attention_scaling
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
pos_e = position_ids[:, None, :].float()
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
sin_split = sin.shape[-1] // 2
return cos, sin[..., :sin_split], -sin[..., sin_split:]
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
"""Attention with per-head sinks.
Sinks add a learned term to each row's softmax denominator but contribute
nothing to the output. We fake this by appending one zero k/v position and
putting the sink logit in the mask at that column.
"""
if num_kv_groups > 1 and not TORCH_HAS_GQA:
k = k.repeat_interleave(num_kv_groups, dim=1)
v = v.repeat_interleave(num_kv_groups, dim=1)
B, _, S_q, D = q.shape
H_kv = k.shape[1]
S_kv = k.shape[-2]
k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
if attention_mask is not None:
mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
else:
mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
mask = torch.cat([mask_left, sinks_col], dim=-1)
op = optimized_attention_for_device(q.device, mask=True, small_input=True)
return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
class GptOssAttention(nn.Module):
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
super().__init__()
self.layer_idx = layer_idx
self.layer_type = config.layer_types[layer_idx]
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.head_dim = config.head_dim
self.hidden_size = config.hidden_size
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
bias = config.attention_bias
self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
B, S, _ = hidden_states.shape
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
q, k = apply_rope(q, k, freqs_cis)
out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
return self.o_proj(out)
# Mixture of Experts
class GptOssTopKRouter(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
super().__init__()
self.top_k = config.num_experts_per_tok
self.num_experts = config.num_local_experts
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
logits = F.linear(hidden_states, weight, bias)
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
# Softmax over top-k slice only
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
return scores, top_idx
class GptOssExperts(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__()
self.num_experts = config.num_local_experts
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.alpha = config.moe_alpha
self.limit = config.moe_limit
E = self.num_experts
H = self.hidden_size
I = self.intermediate_size
self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
gate = gate_up[..., ::2]
up = gate_up[..., 1::2]
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
return torch.addcmul(glu, up, glu)
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
N = hidden_states.shape[0]
top_k = router_indices.shape[-1]
H = hidden_states.shape[-1]
per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
self.down_proj.bank_resident(hidden_states) as down_bank:
for ei in expert_hit:
expert_idx = int(ei.item())
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current = hidden_states[token_idx]
gate_up = gate_up_bank.expert_linear(current, expert_idx)
gated = self._apply_gate(gate_up)
expert_out = down_bank.expert_linear(gated, expert_idx)
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
flat_idx = token_idx * top_k + top_k_pos
per_pair[flat_idx] = weighted.to(per_pair.dtype)
return per_pair.view(N, top_k, H).sum(dim=1)
class GptOssMLP(nn.Module):
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__()
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
B, S, H = hidden_states.shape
flat = hidden_states.reshape(-1, H)
scores, idx = self.router(flat)
out = self.experts(flat, idx, scores)
return out.reshape(B, S, H)
# Decoder layer + model
class GptOssDecoderLayer(nn.Module):
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
super().__init__()
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
self.layer_type = config.layer_types[layer_idx]
def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
residual = x
x = self.input_layernorm(x)
x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
x = residual + x
residual = x
x = self.post_attention_layernorm(x)
x = self.mlp(x)
x = residual + x
return x
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
i = torch.arange(S, device=device).view(-1, 1)
j = torch.arange(S, device=device).view(1, -1)
keep = (j <= i) & (j > i - window)
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
class GptOssModel(nn.Module):
"""GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
super().__init__()
self.config = config
self.dtype = dtype
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList(
[
GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
for i in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
# Always build on CPU so the buffer survives meta-device construction.
inv_freq, attn_scaling = _yarn_inv_freq(
head_dim=config.head_dim,
base=config.rope_theta,
factor=config.rope_factor,
beta_fast=config.rope_beta_fast,
beta_slow=config.rope_beta_slow,
original_max_position_embeddings=config.original_max_position_embeddings,
truncate=config.rope_truncate,
device=torch.device("cpu"),
)
self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
self.rope_attention_scaling = float(attn_scaling)
@property
def num_layers(self) -> int:
return self.config.num_hidden_layers
def get_input_embeddings(self):
return self.embed_tokens
def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
) -> dict[str, torch.Tensor]:
full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
masks = {"full_attention": full}
if any(t == "sliding_attention" for t in self.config.layer_types):
masks["sliding_attention"] = _make_sliding_causal_mask(
B, S, self.config.sliding_window, attention_mask, dtype, device
)
return masks
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
B, S = input_ids.shape
device = input_ids.device
dtype = self.dtype
hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
capture_layers = list(capture_layers) if capture_layers else None
if capture_layers:
max_layer = max(capture_layers)
wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
else:
max_layer = self.config.num_hidden_layers - 1
wanted = None
captured = None
for i, layer in enumerate(self.layers):
hidden_states = layer(hidden_states, attn_masks, freqs_cis)
if wanted is not None and i in wanted:
captured[wanted[i]] = hidden_states
if i >= max_layer:
break
if captured is not None:
return {"hidden_states": captured}
return {"last_hidden_state": self.norm(hidden_states)}
# Lens chat-template constants (verbatim from the reference pipeline).
_LENS_CHAT_SYSTEM = (
"Describe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background."
)
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
LENS_TXT_OFFSET = 97
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
LENS_MAX_TOKENS = 512
# The reference GPT-OSS Harmony template injects today's date here
_LENS_CHAT_DATE = "2026-05-23"
def _lens_render_chat(prompt: str) -> str:
"""Render the Lens prompt in GPT-OSS Harmony format."""
return (
f"<|start|>system<|message|>"
f"You are ChatGPT, a large language model trained by OpenAI.\n"
f"Knowledge cutoff: 2024-06\n"
f"Current date: {_LENS_CHAT_DATE}\n\n"
f"Reasoning: medium\n\n"
f"# Valid channels: analysis, commentary, final. "
f"Channel must be included for every message.<|end|>"
f"<|start|>developer<|message|># Instructions\n\n"
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
f"<|start|>user<|message|>{prompt}<|end|>"
f"<|start|>assistant<|channel|>analysis<|message|>"
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
f"<|start|>assistant<|channel|>final<|message|>"
)
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
class _GptOssRawTokenizer:
"""Raw ``tokenizers.Tokenizer`` wrapper.
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
(``tokenizer_json`` key) rather than as a committed file. It is extracted
in ``sd.py`` and passed here via ``tokenizer_data``.
"""
def __init__(self, tokenizer_json_bytes=None, **kwargs):
from tokenizers import Tokenizer
if isinstance(tokenizer_json_bytes, torch.Tensor):
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
if tokenizer_json_bytes is None:
raise ValueError(
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
"embeds the tokenizer."
)
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
@classmethod
def from_pretrained(cls, tokenizer_data, **kwargs):
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
def __call__(self, text):
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
def get_vocab(self):
return self.tokenizer.get_vocab()
def convert_tokens_to_ids(self, tokens):
return [self.tokenizer.token_to_id(t) for t in tokens]
def decode(self, ids, **kwargs):
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
tokenizer_json_data = None
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
self.tokenizer_json_data = tokenizer_json
super().__init__(
tokenizer_json,
embedding_directory=embedding_directory,
pad_with_end=False,
embedding_size=2880,
embedding_key="gpt_oss",
tokenizer_class=_GptOssRawTokenizer,
has_start_token=False,
has_end_token=False,
pad_to_max_length=False,
max_length=99999999,
min_length=1,
pad_left=False,
disable_weights=True,
tokenizer_data=tokenizer_data,
)
self.pad_token_id = _LENS_PAD_TOKEN_ID
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
# Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
if not text or not text.strip():
return [[]]
rendered = _lens_render_chat(text)
ids = self.tokenizer(rendered)["input_ids"]
if len(ids) > LENS_MAX_TOKENS:
ids = ids[:LENS_MAX_TOKENS]
return [[(int(t), 1.0) for t in ids]]
def state_dict(self):
if self.tokenizer_json_data is not None:
return {"tokenizer_json": self.tokenizer_json_data}
return {}
class LensTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(
embedding_directory=embedding_directory,
tokenizer_data=tokenizer_data,
name="gpt_oss",
tokenizer=LensGptOssTokenizer,
)
class LensGptOssClipModel(nn.Module):
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
super().__init__()
model_options = dict(model_options or {})
operations = model_options.get("custom_operations")
if operations is None:
quant_config = model_options.get("quantization_metadata") or {}
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
self.operations = operations
cfg_overrides = model_options.get("gpt_oss_config", {})
self.config = GptOss20BConfig(**cfg_overrides)
self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
self.num_layers = self.config.num_hidden_layers
self.dtype = dtype
self.execution_device = None
self._pad_token_id = _LENS_PAD_TOKEN_ID
def set_clip_options(self, options):
self.execution_device = options.get("execution_device", self.execution_device)
def reset_clip_options(self):
self.execution_device = None
def _gather_tokens(self, token_weight_pairs):
ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
pad_id = self._pad_token_id
max_len = max(len(x) for x in ids_list)
device = self.execution_device
ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
for i, x in enumerate(ids_list):
ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
mask[i, : len(x)] = 1
return ids, mask
def encode_token_weights(self, token_weight_pairs):
# Empty negative: emit zero-length features + zero mask
if all(len(batch) == 0 for batch in token_weight_pairs):
device = self.execution_device
B = len(token_weight_pairs)
L = len(self.selected_layers)
H = self.config.hidden_size
flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
mask = torch.zeros(B, 0, dtype=torch.long, device=device)
return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
layers = out["hidden_states"] # list of L × [B, S, H]
stacked = torch.stack(layers, dim=2) # [B, S, L, H]
offset = self.txt_offset
if stacked.shape[1] > offset:
stacked = stacked[:, offset:].contiguous()
mask_trim = attn_mask[:, offset:]
else:
stacked = stacked[:, :0]
mask_trim = attn_mask[:, :0]
B, S, L, H = stacked.shape
flat = stacked.reshape(B, S, L * H)
extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
return flat, None, extra
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=True)
class LensTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
class LensTEModel_(LensTEModel):
def __init__(self, device="cpu", dtype=None, model_options=None):
mo = dict(model_options or {})
if llama_quantization_metadata is not None:
mo["quantization_metadata"] = llama_quantization_metadata
if dtype is None and dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=mo)
return LensTEModel_
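
Review note: the `_attention_with_sinks` docstring says a sink adds a term to the softmax denominator without contributing to the output, and that this is emulated by appending one zero key/value position whose mask entry carries the sink logit. The sketch below checks that equivalence for a single query row with scalar values, using only the standard library. All function names here are hypothetical and the scalar setup is a simplification of the batched tensor code.

```python
import math


def softmax_with_sink(logits, sink):
    """Softmax over logits with an extra sink term in the denominator only."""
    m = max(max(logits), sink)
    exps = [math.exp(x - m) for x in logits]
    denom = sum(exps) + math.exp(sink - m)
    return [e / denom for e in exps]


def attend_with_sink(logits, values, sink):
    """Direct form: weights sum to less than 1 and the sink has no value."""
    weights = softmax_with_sink(logits, sink)
    return sum(w * v for w, v in zip(weights, values))


def attend_with_padded_column(logits, values, sink):
    """Emulated form: append a zero value whose logit is the sink, then run
    an ordinary softmax over the extended row."""
    ext_logits = list(logits) + [sink]  # a zero key gives q.k = 0; the mask adds the sink
    ext_values = list(values) + [0.0]
    m = max(ext_logits)
    exps = [math.exp(x - m) for x in ext_logits]
    denom = sum(exps)
    return sum(e / denom * v for e, v in zip(exps, ext_values))


row_logits = [0.5, -1.0, 2.0]
row_values = [1.0, 2.0, 3.0]
direct = attend_with_sink(row_logits, row_values, sink=0.3)
padded = attend_with_padded_column(row_logits, row_values, sink=0.3)
```

Under these assumptions `direct` and `padded` agree up to floating-point rounding, which is why the padded formulation can reuse a standard masked attention kernel.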


@ -0,0 +1,104 @@
import torch
from comfy import sd1_clip
from .lumina2 import Gemma2BTokenizer, LuminaModel
import comfy.text_encoders.llama

class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
        if llama_quantization_metadata is not None:
            model_options = model_options.copy()
            model_options["quantization_metadata"] = llama_quantization_metadata
        super().__init__(
            device=device, layer=layer, layer_idx=layer_idx,
            textmodel_json_config={}, dtype=dtype,
            special_tokens={"start": 2, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Gemma2_2B,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )

_PIXELDIT_CHI_PROMPT = (
    'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
    "suitable for image generation. Evaluate the level of detail in the user prompt:\n"
    "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
    "and spatial relationships to create vivid and concrete scenes.\n"
    "- If the prompt is already detailed, refine and enhance the existing details slightly without "
    "overcomplicating.\n"
    "Here are examples of how to transform or refine prompts:\n"
    "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
    "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
    "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
    "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
    "passing by towering glass skyscrapers.\n"
    "Please generate only the enhanced description for the prompt below and avoid including any "
    "additional commentary or evaluations:\n"
    "User Prompt: "
)
_PIXELDIT_MAX_LENGTH = 300
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'

class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data=None):
        if tokenizer_data is None:
            tokenizer_data = {}
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
                         name="gemma2_2b", tokenizer=Gemma2BTokenizer)

    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
        if not text.strip():
            return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
        chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
        combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
        max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
        out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
                                            disable_weights=True, min_length=max_length_all)
        out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
        return out

    def untokenize(self, token_weight_pair):
        return self.gemma2_2b.untokenize(token_weight_pair)

    def state_dict(self):
        return self.gemma2_2b.state_dict()

class PixelDiTGemma2TE(LuminaModel):
    # PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="gemma2_2b",
                         clip_model=PixelDiTGemma2_2BModel, model_options=model_options)

    def encode_token_weights(self, token_weight_pairs):
        result = super().encode_token_weights(token_weight_pairs)
        cond, pooled = result[0], result[1]
        extra = result[2] if len(result) > 2 else None
        if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
            cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
            if extra is not None and "attention_mask" in extra:
                am = extra["attention_mask"]
                extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
        if extra is not None:
            return cond, pooled, extra
        return cond, pooled

def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
    class PixelDiTTE_(PixelDiTGemma2TE):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_quantization_metadata is not None:
                model_options = model_options.copy()
                model_options["llama_quantization_metadata"] = llama_quantization_metadata
            if dtype_llama is not None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return PixelDiTTE_
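
The comment on `PixelDiTGemma2TE` describes the truncation as keeping BOS plus the last 299 embeddings of the padded sequence. The following is a minimal sketch of that index selection on a plain Python list, so it runs without torch. `select_index` and `MAX_LENGTH` are hypothetical names used only for illustration; the code in the diff performs the same slicing on tensors along dimension 1 with `torch.cat`.

```python
MAX_LENGTH = 300  # mirrors _PIXELDIT_MAX_LENGTH in the diff


def select_index(sequence, max_length=MAX_LENGTH):
    """Keep the first item (BOS) plus the last max_length - 1 items.

    Hypothetical list-based stand-in for the tensor slicing in
    PixelDiTGemma2TE.encode_token_weights; inputs that are already
    short enough pass through unchanged.
    """
    if len(sequence) <= max_length:
        return list(sequence)
    return list(sequence[:1]) + list(sequence[-(max_length - 1):])


# Toy input: 1 BOS marker, 400 instruction-prefix positions,
# then 299 user-prompt positions.
tokens = ["BOS"] + ["chi"] * 400 + ["user"] * 299
kept = select_index(tokens)
```

In this toy input the 400 prefix positions fall outside the kept window, so `kept` holds the BOS marker followed only by user-prompt positions.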


@ -86,6 +86,7 @@ def load_safetensors(ckpt):
 import comfy_aimdo.model_mmap
 f = open(ckpt, "rb", buffering=0)
+file_lock = threading.Lock()
 model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
 file_size = os.path.getsize(ckpt)
 mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@ -111,9 +112,8 @@ def load_safetensors(ckpt):
 storage = tensor.untyped_storage()
 setattr(storage,
 "_comfy_tensor_file_slice",
-comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
+comfy.memory_management.TensorFileSlice(f, file_lock, data_base_offset + start, end - start))
 setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
-setattr(storage, "_comfy_tensor_mmap_touched", False)
 sd[name] = tensor
 return sd, header.get("__metadata__", {}),
@ -1020,10 +1020,11 @@ def bislerp(samples, width, height):
 def lanczos(samples, width, height):
     #the below API is strict and expects grayscale to be squeezed
-    samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
+    if samples.ndim == 4:
+        samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1)
     images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples]
     images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images]
-    images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images]
+    images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images]
     result = torch.stack(images)
     return result.to(samples.device, samples.dtype)
@ -1451,4 +1452,3 @@ def deepcopy_list_dict(obj, memo=None):
 memo[obj_id] = res
 return res
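
The `load_safetensors` hunks replace the `threading.get_ident()` argument of `TensorFileSlice` with a single `threading.Lock` created once per opened file. The diff does not show how `comfy.memory_management` uses that lock, so the following is only a sketch of the general pattern it suggests: several slices of one unbuffered file handle share one lock, so that seek-and-read pairs issued from different threads cannot interleave. `FileSlice`, `read_slice` and `worker` are hypothetical names, not part of the codebase.

```python
import os
import tempfile
import threading
from dataclasses import dataclass
from typing import BinaryIO


@dataclass
class FileSlice:
    """Hypothetical stand-in for a (file, lock, offset, size) record."""
    file: BinaryIO
    lock: threading.Lock
    offset: int
    size: int


def read_slice(s: FileSlice) -> bytes:
    # seek() and read() share the handle's file position, so the pair
    # runs under the lock shared by every slice of this file.
    with s.lock:
        s.file.seek(s.offset)
        return s.file.read(s.size)


# Build a small file with three 4-byte regions.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"AAAABBBBCCCC")
    path = tmp.name

f = open(path, "rb", buffering=0)
file_lock = threading.Lock()  # one lock per open file, as in the diff
slices = [FileSlice(f, file_lock, i * 4, 4) for i in range(3)]

results = [None] * len(slices)


def worker(i):
    results[i] = read_slice(slices[i])


threads = [threading.Thread(target=worker, args=(i,)) for i in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()

f.close()
os.remove(path)
```

A lock object travels with the slice and can be acquired by whichever thread performs the read, which a recorded thread id cannot do; whether that is the motivation for this change is an assumption.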


@ -1,52 +0,0 @@
import ctypes
import logging
import psutil
from ctypes import wintypes
import comfy_aimdo.control

psapi = ctypes.WinDLL("psapi")
kernel32 = ctypes.WinDLL("kernel32")

class PERFORMANCE_INFORMATION(ctypes.Structure):
    _fields_ = [
        ("cb", wintypes.DWORD),
        ("CommitTotal", ctypes.c_size_t),
        ("CommitLimit", ctypes.c_size_t),
        ("CommitPeak", ctypes.c_size_t),
        ("PhysicalTotal", ctypes.c_size_t),
        ("PhysicalAvailable", ctypes.c_size_t),
        ("SystemCache", ctypes.c_size_t),
        ("KernelTotal", ctypes.c_size_t),
        ("KernelPaged", ctypes.c_size_t),
        ("KernelNonpaged", ctypes.c_size_t),
        ("PageSize", ctypes.c_size_t),
        ("HandleCount", wintypes.DWORD),
        ("ProcessCount", wintypes.DWORD),
        ("ThreadCount", wintypes.DWORD),
    ]

def get_free_ram():
    #Windows is way too conservative and chalks recently used uncommitted model RAM
    #as "in-use". So, calculate free RAM for the sake of general use as the greater of:
    #
    #1: What psutil says
    #2: Total Memory - (Committed Memory - VRAM in use)
    #
    #We have to subtract VRAM in use from the committed memory as WDDM creates a naked
    #commit charge for all VRAM used just in case it wants to page it all out. This just
    #isn't realistic so "overcommit" on our calculations by just subtracting it off.
    pi = PERFORMANCE_INFORMATION()
    pi.cb = ctypes.sizeof(pi)
    if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
        logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal")
        return psutil.virtual_memory().available
    committed = pi.CommitTotal * pi.PageSize
    total = pi.PhysicalTotal * pi.PageSize
    return max(psutil.virtual_memory().available,
               total - (committed - comfy_aimdo.control.get_total_vram_usage()))
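
The comment in the removed `get_free_ram` describes the estimate as the greater of what psutil reports and total memory minus the commit charge with VRAM usage subtracted. The arithmetic can be sketched as a pure function, separated from the Windows API calls so it runs on any platform. `estimate_free_ram` and its parameter names are hypothetical; the removed code read these values from `psutil`, `GetPerformanceInfo` and `comfy_aimdo.control`.

```python
def estimate_free_ram(available, physical_total, committed, vram_in_use):
    """Return the larger of the two estimates described in the comment.

    All arguments are byte counts. This is an illustrative helper,
    not a function from the codebase.
    """
    # Commit charge with the VRAM-backed portion subtracted off.
    adjusted_commit = committed - vram_in_use
    return max(available, physical_total - adjusted_commit)


GIB = 1024 ** 3

# Example numbers (made up): 32 GiB machine, 30 GiB committed,
# of which 10 GiB is commit charge attributed to VRAM.
free = estimate_free_ram(
    available=4 * GIB,
    physical_total=32 * GIB,
    committed=30 * GIB,
    vram_in_use=10 * GIB,
)
```

With these made-up numbers the second estimate (32 - 20 = 12 GiB) exceeds the reported 4 GiB, so it is the one returned. When the commit charge exceeds physical memory, the `max` falls back to the reported value.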


@ -1,5 +1,3 @@
-from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 from comfy_api.internal import ComfyAPIBase


@ -1,4 +1,3 @@
-from __future__ import annotations
 from av.container import InputContainer
 from av.subtitles.stream import SubtitleStream
 from fractions import Fraction


@ -762,14 +762,32 @@ class Accumulation(ComfyTypeIO):
 @comfytype(io_type="LOAD3D_CAMERA")
 class Load3DCamera(ComfyTypeIO):
     class CameraInfo(TypedDict):
-        position: dict[str, float | int]
-        target: dict[str, float | int]
-        zoom: int
-        cameraType: str
+        # Coordinate system: right-handed, Y-up, camera looks down -Z
+        position: dict[str, float | int] # scene units
+        target: dict[str, float | int] # scene units; OrbitControls focus point
+        zoom: float | int # dimensionless, 1 = 100%
+        cameraType: str # 'perspective' | 'orthographic'
+        quaternion: NotRequired[dict[str, float | int]] # normalized, dimensionless; camera world rotation
+        fov: NotRequired[float | int] # degrees, vertical FOV (perspective only)
+        aspect: NotRequired[float | int] # width / height (perspective only)
+        near: NotRequired[float | int] # scene units
+        far: NotRequired[float | int] # scene units
+        frustum: NotRequired[dict[str, float | int]] # orthographic only: {left, right, top, bottom} in scene units
     Type = CameraInfo
+@comfytype(io_type="LOAD3D_MODEL_INFO")
+class Load3DModelInfo(ComfyTypeIO):
+    class Model3DTransform(TypedDict):
+        # Coordinate system: right-handed, Y-up, world space
+        position: dict[str, float | int] # scene units
+        quaternion: dict[str, float | int] # normalized, dimensionless; world rotation
+        scale: dict[str, float | int] # dimensionless multiplier
+    Type = list[Model3DTransform]
 @comfytype(io_type="LOAD_3D")
 class Load3D(ComfyTypeIO):
     """3D models are stored as a dictionary."""
@ -779,6 +797,7 @@ class Load3D(ComfyTypeIO):
         normal: str
         camera_info: Load3DCamera.CameraInfo
         recording: NotRequired[str]
+        model_3d_info: NotRequired[list[Load3DModelInfo.Model3DTransform]]
     Type = Model3DDict
@ -2291,6 +2310,7 @@ __all__ = [
     "FlowControl",
     "Accumulation",
     "Load3DCamera",
+    "Load3DModelInfo",
     "Load3D",
     "Load3DAnimation",
     "Photomaker",


@ -1,4 +1,3 @@
-from __future__ import annotations
 from dataclasses import dataclass
 from enum import Enum
 from fractions import Fraction


@ -3,7 +3,6 @@
 # timestamp: 2025-07-30T08:54:00+00:00
 # pylint: disable
-from __future__ import annotations
 from datetime import date, datetime
 from enum import Enum


@ -35,6 +35,19 @@ class AnthropicMessage(BaseModel):
     content: list[AnthropicTextContent | AnthropicImageContent] = Field(...)
+class AnthropicThinkingConfig(BaseModel):
+    type: Literal["enabled", "disabled", "adaptive"] = Field(...)
+    budget_tokens: int | None = Field(
+        None, ge=1024,
+        description="Reasoning budget in tokens. Used when type is 'enabled'. Must be less than max_tokens.",
+    )
+class AnthropicOutputConfig(BaseModel):
+    """Used with `thinking.type='adaptive'` on models like Opus 4.7."""
+    effort: Literal["low", "medium", "high"] | None = Field(None)
 class AnthropicMessagesRequest(BaseModel):
     model: str = Field(...)
     messages: list[AnthropicMessage] = Field(...)
@ -44,6 +57,8 @@ class AnthropicMessagesRequest(BaseModel):
     top_p: float | None = Field(None, ge=0.0, le=1.0)
     top_k: int | None = Field(None, ge=0)
     stop_sequences: list[str] | None = Field(None)
+    thinking: AnthropicThinkingConfig | None = Field(None)
+    output_config: AnthropicOutputConfig | None = Field(None)
 class AnthropicResponseTextBlock(BaseModel):
@ -51,6 +66,14 @@ class AnthropicResponseTextBlock(BaseModel):
     text: str = Field(...)
+class AnthropicResponseThinkingBlock(BaseModel):
+    type: Literal["thinking"] = "thinking"
+    thinking: str = Field(...)
+AnthropicResponseBlock = AnthropicResponseTextBlock | AnthropicResponseThinkingBlock
 class AnthropicCacheCreationUsage(BaseModel):
     ephemeral_5m_input_tokens: int | None = Field(None)
     ephemeral_1h_input_tokens: int | None = Field(None)
@ -69,7 +92,7 @@ class AnthropicMessagesResponse(BaseModel):
     type: str | None = Field(None)
     role: str | None = Field(None)
     model: str | None = Field(None)
-    content: list[AnthropicResponseTextBlock] | None = Field(None)
+    content: list[AnthropicResponseBlock] | None = Field(None)
     stop_reason: str | None = Field(None)
     stop_sequence: str | None = Field(None)
     usage: AnthropicMessagesUsage | None = Field(None)
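
With `AnthropicResponseThinkingBlock` added to the `content` union, a response may mix thinking blocks and text blocks. Below is a sketch of how a caller might separate the two, written against plain dicts rather than the pydantic models so that it runs without dependencies. `split_blocks` is a hypothetical helper, and the `"text"` type tag for text blocks is an assumption, since that field's definition lies outside the hunk.

```python
def split_blocks(content):
    """Separate reasoning text from answer text in a response's content list.

    `content` is a list of dicts shaped like the response blocks in the diff:
    {"type": "thinking", "thinking": "..."} or {"type": "text", "text": "..."}.
    Blocks of any other type are ignored. A missing list yields empty strings.
    """
    thinking_parts = []
    text_parts = []
    for block in content or []:
        if block.get("type") == "thinking":
            thinking_parts.append(block.get("thinking", ""))
        elif block.get("type") == "text":
            text_parts.append(block.get("text", ""))
    return "\n".join(thinking_parts), "\n".join(text_parts)


# Made-up response content for illustration.
content = [
    {"type": "thinking", "thinking": "Compare both options."},
    {"type": "text", "text": "Option B is cheaper."},
]
reasoning, answer = split_blocks(content)
```

Code that previously assumed every block carried a `text` field would need a check of this kind once thinking is enabled on the request.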


@ -0,0 +1,32 @@
from pydantic import BaseModel, Field

class CreateSwitchXRequest(BaseModel):
    generation_type: str = Field(...)
    source_uri: str = Field(...)
    alpha_mode: str = Field(...)
    prompt: str | None = Field(None, max_length=2000)
    reference_image_uri: str | None = Field(None)
    alpha_uri: str | None = Field(None)
    max_resolution: int = Field(1080)
    callback_url: str | None = Field(None)
    idempotency_key: str | None = Field(None, max_length=256, min_length=1)

class SwitchXOutputUrls(BaseModel):
    render: str | None = Field(None)
    source: str | None = Field(None)
    alpha: str | None = Field(None)

class SwitchXStatusResponse(BaseModel):
    id: str = Field(...)
    status: str = Field(...)
    progress: int | None = Field(None)
    generation_type: str | None = Field(None)
    alpha_mode: str | None = Field(None)
    output: SwitchXOutputUrls | None = Field(None)
    error: str | None = Field(None)
    created_at: str | None = Field(None)
    modified_at: str | None = Field(None)
    completed_at: str | None = Field(None)


@ -1,5 +1,3 @@
-from __future__ import annotations
 from enum import Enum
 from typing import Any, Dict, Optional


@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel):
 class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
-    url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
+    url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
     hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
+    asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
 # Dollars per 1K tokens, keyed by (model_id, has_video_input).


@ -0,0 +1,46 @@
"""Pydantic models for the Krea image-generation API."""
from pydantic import BaseModel, Field

class KreaMoodboard(BaseModel):
    id: str = Field(...)
    strength: float = Field(default=0.35, ge=-0.5, le=1.5)

class KreaImageStyleReference(BaseModel):
    strength: float = Field(..., ge=-2.0, le=2.0)
    url: str | None = Field(default=None)

class KreaGenerateImageRequest(BaseModel):
    prompt: str = Field(...)
    aspect_ratio: str = Field(...)
    resolution: str = Field(...)
    seed: int | None = Field(default=None)
    creativity: str = Field(default="medium")
    moodboards: list[KreaMoodboard] | None = Field(default=None)
    image_style_references: list[KreaImageStyleReference] | None = Field(default=None)

class KreaJobResult(BaseModel):
    urls: list[str] | None = Field(default=None)
    style_id: str | None = Field(default=None)

class KreaJob(BaseModel):
    job_id: str = Field(...)
    status: str = Field(...)
    created_at: str = Field(...)
    completed_at: str | None = Field(default=None)
    result: KreaJobResult | None = Field(default=None)

class KreaAssetResponse(BaseModel):
    id: str = Field(...)
    image_url: str = Field(...)
    uploaded_at: str = Field(...)
    width: float | None = Field(default=None)
    height: float | None = Field(default=None)
    size_bytes: float | None = Field(default=None)
    mime_type: str | None = Field(default=None)

Some files were not shown because too many files have changed in this diff.