Compare commits

...

19 Commits

Author SHA1 Message Date
Ray Suhyun Lee
606bd36d90
Merge b947b5a4a3 into 95f6652ef5 2026-05-10 15:48:42 +08:00
LaVie024
95f6652ef5
Add Boolean support to Math Expression Node (#13224)
* Add Boolean support to math expressions

* Change boolean result test to assert values

---------

Co-authored-by: Alexis Rolland <alexisrolland@hotmail.com>
2026-05-10 15:33:47 +08:00
comfyanonymous
20f5e474da
Use LatentCutToBatch instead. (#13815)
Removed VAEDecodeVideoFramewise from nodes_wandancer.py.
2026-05-09 14:17:00 -07:00
Jukka Seppänen
3200f28e3a
Support Wan-Dancer (#13813)
* initial WanDancer support

* nodes_wandancer: Add list form of chunker.

Create an alternate list form of the node so the chunk gens can be
trivially looped by the comfy executor.

* Closer match to original soxr resampling

* Remove librosa node

* Cleanup

---------

Co-authored-by: Rattus <rattus128@gmail.com>
2026-05-09 14:02:56 -07:00
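The "list form of chunker" above leans on ComfyUI's list-output convention: a node that declares OUTPUT_IS_LIST returns a Python list, and the executor then runs downstream nodes once per element, so the per-chunk generation loop lives in the graph rather than inside the node. A minimal sketch of that pattern follows; the AudioChunkerList name, the chunking logic, and the AUDIO payload layout are illustrative assumptions, not the actual nodes_wandancer implementation.

class AudioChunkerList:
    # Each element of the returned list is dispatched to downstream nodes separately,
    # so the Comfy executor performs the chunk-by-chunk loop implicitly.
    OUTPUT_IS_LIST = (True,)
    RETURN_TYPES = ("AUDIO",)
    FUNCTION = "chunk"
    CATEGORY = "audio"

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "audio": ("AUDIO",),
            "chunk_samples": ("INT", {"default": 48000, "min": 1}),
        }}

    def chunk(self, audio, chunk_samples):
        waveform = audio["waveform"]  # assumed layout: [batch, channels, samples]
        chunks = [
            {"waveform": waveform[..., i:i + chunk_samples], "sample_rate": audio["sample_rate"]}
            for i in range(0, waveform.shape[-1], chunk_samples)
        ]
        return (chunks,)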
Comfy Org PR Bot
a4b7e3beed
Bump comfyui-frontend-package to 1.43.18 (#13809)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-05-09 07:53:10 -07:00
Alexander Piskun
7bbf1e8169
[Partner Nodes] Tripo3D 3.1 model (#13788)
* feat(api-nodes): add Tripo3D 3.1 model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* fix: price badges algo

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* [Partner Nodes] deprecate "quad" param for the TripoMultiviewToModel node

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-08 21:38:17 -07:00
lin-bot23
8b08bfdcbe
Add description field to blueprint subgraphs (#13797)
* Add description field to all blueprint subgraphs

Sets the 'description' field on every subgraph blueprint node,
which will show on the node preview and tooltip. Covers all 51
blueprint files under blueprints/.

* Update blueprint descriptions with researched model info

* Refine blueprint descriptions with researched model specs from docs

Updates subgraph descriptions across all 51 blueprints with accurate
model details drawn from ComfyUI docs, including:
- Flux.1 Dev: 12B open-weights, Pro-level quality
- Flux.2 Klein 4B: fastest Flux, distilled architecture
- Qwen-Image: 20B MMDiT, multilingual text rendering
- Z-Image-Turbo: distilled 6B DiT, sub-second inference
- LTX-2/2.3: 19B DiT audio-video foundation model
- Wan2.2: open-source, 14B/1.3B variants
- ACE-Step 1.5: ~1s full-song generation
- GPU shader nodes consistently labeled as fragment shaders

* Strip marketing fluff and license info from descriptions

* Fix Canny to Video (LTX 2.0) description

* Remove 'local-' prefix from subgraph names

* Preserve UTF-8 encoding in JSON files (ensure_ascii=False)

* Apply review suggestions from alexisrolland

- Rename 'Image to Model (Hunyuan3d 2.1)' -> 'Image to 3D Model (Hunyuan3d 2.1)'
- Rename 'Image Upscale(Z-image-Turbo)' -> 'Image Upscale (Z-image-Turbo)'
- Rename 'Video Inpaint(Wan2.1 VACE)' -> 'Video Inpaint (Wan 2.1 VACE)'
- Use 'Black Forest Labs' branding in Flux descriptions
- Use 'Google's Gemini' with possessive in captioning nodes
- Normalize 'Wan 2.2' and 'Wan 2.1' spacing in descriptions

* fix: revert Color Adjustment.json to preserve original GLSL shader content

Only adds the 'description' field without modifying the shader code
(which contained Unicode escape \\u2192 that should be preserved).

* Apply CodeRabbit review suggestions

- Color Adjustment: include vibrance in description
- Image Blur: expand to Gaussian/Box/Radial modes
- Flux.2 Klein 4B: narrow to image edit only (no T2I)
- NetaYume Lumina: correct model base (Neta Lumina, not Lumina-Next)

---------

Co-authored-by: linmoumou <linmoumou@linmoumoudeMac-mini.local>
Co-authored-by: Daxiong (Lin) <contact@comfyui-wiki.com>
2026-05-09 11:26:13 +08:00
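Mechanically, the description work above comes down to loading each blueprint JSON, setting the subgraph's "description" field, and writing the file back with ensure_ascii=False so UTF-8 text survives the round trip instead of being escaped. A rough sketch of that kind of batch edit, assuming a flat blueprints/ directory and a hypothetical location of the subgraph list inside each file:

import json
from pathlib import Path

# Assumed mapping from blueprint file name to its researched description text.
DESCRIPTIONS = {
    "Text to Image (Z-Image-Turbo).json": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model.",
}

for path in Path("blueprints").glob("*.json"):
    if path.name not in DESCRIPTIONS:
        continue
    data = json.loads(path.read_text(encoding="utf-8"))
    # Assumption: subgraph definitions live under data["definitions"]["subgraphs"];
    # adjust to the real blueprint schema.
    for subgraph in data.get("definitions", {}).get("subgraphs", []):
        subgraph["description"] = DESCRIPTIONS[path.name]
    # ensure_ascii=False writes non-ASCII characters literally instead of as \uXXXX escapes.
    path.write_text(json.dumps(data, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")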
Matt Miller
4e823431cc
Add cloud-runtime experiment node-schema endpoints to spec (#13806)
* Add cloud-runtime experiment node-schema endpoints to spec

Replace the GET operations at /api/experiment/nodes and
/api/experiment/nodes/{id} with getNodeInfoSchema and getNodeByID —
the optimized, ETag-tagged object_info schema endpoints the cloud
frontend depends on for the workflow editor.

Each operation is tagged x-runtime: [cloud] and uses the runtime-only
tag for cloud-side codegen exclusion. Response headers document the
ETag and Cache-Control validators; 304 Not Modified is declared for
RFC 7232 conditional GETs.

Remove the now-unused CloudNodeList schema to keep Spectral clean.

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>

* spec: document If-None-Match header on conditional GET endpoints

Both `getNodeInfoSchema` and `getNodeByID` advertise `ETag` response
headers and a `304 Not Modified` response, but the spec didn't declare
the `If-None-Match` request header that triggers conditional validation.
Adding it as an optional header parameter on both ops so client codegen
exposes the conditional-GET pattern.
2026-05-08 19:14:23 -07:00
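The conditional-GET behaviour the two spec commits above document is standard RFC 7232 validation: the client replays the ETag it received earlier in an If-None-Match request header, and the server answers 304 Not Modified when the cached schema is still current. A minimal client-side sketch using requests; only the endpoint path and headers come from the commit text, the caching wrapper itself is illustrative.

import requests

_cache = {"etag": None, "body": None}

def get_node_schema(base_url="http://127.0.0.1:8188"):
    headers = {}
    if _cache["etag"]:
        headers["If-None-Match"] = _cache["etag"]  # conditional GET per RFC 7232
    resp = requests.get(f"{base_url}/api/experiment/nodes", headers=headers, timeout=30)
    if resp.status_code == 304:  # schema unchanged; serve the cached copy
        return _cache["body"]
    resp.raise_for_status()
    _cache["etag"] = resp.headers.get("ETag")
    _cache["body"] = resp.json()
    return _cache["body"]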
comfyanonymous
66669b2ded
I don't think there was any because nobody complained. (#13807)
2026-05-08 17:32:14 -07:00
Alexander Piskun
65045730a6
[Partner Nodes] additionally use a Baidu server to detect internet accessibility (#13803)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-08 13:11:52 -07:00
Matt Miller
87878f354f
Add cloud-runtime FE-facing operations to spec (#13734)
* Add cloud-runtime FE-facing operations to openapi.yaml

Add ~67 cloud-runtime FE-facing path operations to the core OpenAPI spec,
each tagged with x-runtime: [cloud] at the operation level. These operations
are served by the cloud runtime; the local runtime returns 404 for all of
these paths.

Domain groups added:
- Jobs / prompts: /api/job/*, /api/jobs/*/cancel, /api/prompt/*, etc.
- History v2: /api/history_v2, /api/history_v2/{prompt_id}
- Cloud logs: /api/logs
- Asset extensions: /api/assets/download, export, import, etc.
- Custom nodes: /api/experiment/nodes (cloud install/uninstall)
- Hub: /api/hub/profiles, /api/hub/workflows, /api/hub/labels, etc.
- Workflows: /api/workflows CRUD, versioning, fork, publish
- Auth/session: /api/auth/session, /api/auth/token, /.well-known/jwks.json
- Billing: /api/billing/balance, plans, subscribe, topup, etc.
- Workspace: /api/workspace/*, /api/workspaces/*
- User/settings/misc: /api/user, /api/secrets, /api/feedback, etc.

Also adds corresponding cloud-only component schemas (CloudJob, CloudWorkflow,
BillingPlan, Workspace, HubProfile, AuthSession, etc.), all tagged with
x-runtime: [cloud].

Spectral lint passes under the existing ruleset with zero new warnings.

* Add job_id field to Asset schema and deprecate prompt_id (#13736)

- Add job_id as a nullable UUID field to the Asset schema
- Mark prompt_id as deprecated with note pointing to job_id
- No x-runtime tag needed as both runtimes populate the field

* Add hash field to Asset schemas and deprecate asset_hash (#13738)

- Add 'hash' as a nullable string field to Asset and AssetUpdated schemas
- Mark 'asset_hash' as deprecated with a note pointing to 'hash'
- AssetCreated inherits 'hash' via allOf from Asset
- Spectral lint clean (no new warnings)

* Fix method drift on cloud-runtime endpoints

Three PUT operations were added that should be PATCH (cloud serves
PATCH for partial updates):

- /api/workflows/{workflow_id}
- /api/workspaces/{id}
- /api/workspace/members/{userId}

Two POST operations were added that should be GET (cloud serves GET
with query params):

- /api/assets/remote-metadata (url moves to query param)
- /api/files/mask-layers (response shape replaced — operation queries
  related mask layer filenames, not file uploads)

* Add missing cloud-runtime operations and schemas

PR review surfaced operations the cloud runtime serves that weren't
covered by the initial spec push, plus one path family missed entirely.

New methods on existing paths:

- /api/auth/session: add POST (create session cookie) and DELETE (logout)
- /api/secrets/{id}: add GET (read metadata) and PATCH (update)
- /api/hub/profiles: add POST (create profile)
- /api/hub/workflows: add POST (publish to hub)
- /api/hub/workflows/{share_id}: add DELETE (unpublish)
- /api/workspaces/{id}: add DELETE (soft-delete workspace)
- /api/workspace/members/{user_id}/api-keys: add DELETE (bulk revoke)
- /api/workflows/{workflow_id}/versions: add POST (create new version)
- /api/userdata/{file}/publish: add GET (read publish info)

New path family:

- /api/tasks (GET list) and /api/tasks/{task_id} (GET detail) for the
  background task framework

New component schemas (all tagged x-runtime: [cloud]):

CreateSessionResponse, DeleteSessionResponse, UpdateSecretRequest,
BulkRevokeAPIKeysResponse, CreateHubProfileRequest, PublishHubWorkflowRequest,
HubWorkflowDetail, AssetInfo, CreateWorkflowVersionRequest,
WorkflowVersionResponse, WorkflowPublishInfo, TaskEntry, TaskResponse,
TasksListResponse. Existing SecretMeta extended with provider and
last_used_at fields the cloud runtime actually returns.

New tag: task. Spectral lint passes with zero errors.

* Add job_id and prompt_id to AssetUpdated schema

Mirrors the Asset schema's deprecation pattern: prompt_id is marked
deprecated with a description pointing to job_id; job_id is the new
preferred field. PUT /api/assets/{id} responses can now carry both fields
consistent with the other Asset-returning endpoints.

* feat: add width and height fields to Asset schema (#13745)

Add nullable integer fields 'width' and 'height' to the Asset schema
in openapi.yaml. These expose original image dimensions in pixels for
clients that need pre-thumbnail size info. Both fields are null for
non-image assets or assets ingested before dimension extraction.

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>

* Remove /api/job/{job_id} and /api/job/{job_id}/outputs

These two paths are not actually served by the cloud runtime — they
return 404 with a redirect message pointing callers to the canonical
`/api/jobs/{job_id}` (plural). Declaring them with `x-runtime: [cloud]`
and a 200 response schema is incorrect.

`/api/job/{job_id}/status` stays — it is a real cloud-served endpoint.

Also drops the now-orphaned `CloudJob` and `CloudJobOutputs` component
schemas. `CloudJobStatus` is retained.
2026-05-08 12:39:16 -07:00
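Because every cloud-only operation in the spec change above carries an x-runtime: [cloud] vendor extension at the operation level, downstream tooling can filter on that key, for example to exclude those operations from local-runtime codegen. A small sketch of such a filter over openapi.yaml; it assumes PyYAML and the extension name exactly as the commit describes, and is not part of the PR itself.

import yaml

HTTP_METHODS = {"get", "put", "post", "delete", "patch", "head", "options"}

with open("openapi.yaml", "r", encoding="utf-8") as f:
    spec = yaml.safe_load(f)

cloud_only = []
for path, item in spec.get("paths", {}).items():
    for method, op in item.items():
        if method not in HTTP_METHODS or not isinstance(op, dict):
            continue
        if op.get("x-runtime") == ["cloud"]:  # operation served only by the cloud runtime
            cloud_only.append(f"{method.upper()} {path}")

print("\n".join(sorted(cloud_only)))  # e.g. feed this into a codegen exclusion list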
Alexis Rolland
c5ecd231a2
fix: Fix bug when mask not on same device (CORE-181) (#13801)
2026-05-08 23:06:29 +08:00
drozbay
9864f5ac86
fix: Stop LTXVImgToVideoInplace from mutating input latents and dropping noise_mask (#13793) 2026-05-08 23:02:17 +08:00
drozbay
05cd076bc1
fix: Make LTXVAddGuide center-crop guide images to match other LTXV nodes (#13794) 2026-05-08 22:48:59 +08:00
Yousef R. Gamaleldin
d3c18c1636
Add support for BiRefNet background remove model (CORE-46) (#12747) 2026-05-08 17:59:24 +08:00
omahs
bac6fc35fb
Fix typos (#10986) 2026-05-08 17:14:45 +08:00
Alexander Piskun
56c74094c7
[Partner Nodes] use "adaptive" aspect ratio for SD2 nodes (#13800)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-07 23:39:13 -07:00
Ray Suhyun Lee
b947b5a4a3 fix other conditions as well 2023-09-28 18:23:58 +09:00
Ray Suhyun Lee
297757778d fix /view returning image with wrong orientation 2023-09-28 17:49:35 +09:00
82 changed files with 7143 additions and 218 deletions

View File

@ -431,9 +431,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Adjusts image brightness and contrast using a real-time GPU fragment shader."
}
]
},
"extra": {}
}
}

View File

@ -162,7 +162,7 @@
},
"revision": 0,
"config": {},
"name": "local-Canny to Image (Z-Image-Turbo)",
"name": "Canny to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1553,7 +1553,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Canny to image"
"category": "Image generation and editing/Canny to image",
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
}
]
},
@ -1574,4 +1575,4 @@
}
},
"version": 0.4
}
}

View File

@ -192,7 +192,7 @@
},
"revision": 0,
"config": {},
"name": "local-Canny to Video (LTX 2.0)",
"name": "Canny to Video (LTX 2.0)",
"inputNode": {
"id": -10,
"bounding": [
@ -3600,7 +3600,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Canny to video"
"category": "Video generation and editing/Canny to video",
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
}
]
},
@ -3616,4 +3617,4 @@
}
},
"version": 0.4
}
}

View File

@ -377,8 +377,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Adds lens-style chromatic aberration (color fringing) using a real-time GPU fragment shader."
}
]
}
}
}

View File

@ -596,7 +596,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Adjusts saturation, temperature, tint, and vibrance using a real-time GPU fragment shader."
}
]
}

View File

@ -1129,7 +1129,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Balances colors across shadows, midtones, and highlights using a real-time GPU fragment shader."
}
]
}

View File

@ -608,7 +608,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Fine-tunes tone and color with per-channel curve adjustments using a real-time GPU fragment shader."
}
]
}

View File

@ -1609,7 +1609,8 @@
}
],
"extra": {},
"category": "Image Tools/Crop"
"category": "Image Tools/Crop",
"description": "Splits an image into a 2×2 grid of four equal tiles."
}
]
},

View File

@ -2946,7 +2946,8 @@
}
],
"extra": {},
"category": "Image Tools/Crop"
"category": "Image Tools/Crop",
"description": "Splits an image into a 3×3 grid of nine equal tiles."
}
]
},

View File

@ -1579,7 +1579,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Depth to image"
"category": "Image generation and editing/Depth to image",
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
},
{
"id": "458bdf3c-4b58-421c-af50-c9c663a4d74c",
@ -2461,7 +2462,8 @@
]
},
"workflowRendererVersion": "LG"
}
},
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
]
},

View File

@ -4233,7 +4233,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Depth to video"
"category": "Video generation and editing/Depth to video",
"description": "Generates video from depth maps using LTX-2, with optional synchronized audio."
},
{
"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2",
@ -5192,7 +5193,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
}
},
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
]
},

View File

@ -450,9 +450,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Blur"
"category": "Image Tools/Blur",
"description": "Applies bilateral (edge-preserving) blur to soften images while retaining detail."
}
]
},
"extra": {}
}
}

View File

@ -580,8 +580,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Adds procedural film grain texture for a cinematic look via GPU fragment shader."
}
]
}
}
}

View File

@ -3350,7 +3350,8 @@
}
],
"extra": {},
"category": "Video generation and editing/First-Last-Frame to Video"
"category": "Video generation and editing/First-Last-Frame to Video",
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
}
]
},

View File

@ -575,8 +575,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Adds a glow/bloom effect around bright image areas via GPU fragment shader."
}
]
}
}
}

View File

@ -752,8 +752,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Adjusts hue, saturation, and lightness of an image using a real-time GPU fragment shader."
}
]
}
}
}

View File

@ -374,7 +374,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Blur"
"category": "Image Tools/Blur",
"description": "Applies Gaussian, Box, or Radial blur to soften images and create stylized depth or motion effects."
}
]
}

View File

@ -310,7 +310,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Image Captioning"
"category": "Text generation/Image Captioning",
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
}
]
}

View File

@ -315,8 +315,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Manipulates individual RGBA channels for masking, compositing, and channel effects."
}
]
}
}
}

View File

@ -2138,7 +2138,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Edit image"
"category": "Image generation and editing/Edit image",
"description": "Edits images via text instructions using FireRed Image Edit 1.1, a diffusion-based instruction-following editing model."
}
]
},

View File

@ -1472,7 +1472,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Edit image"
"category": "Image generation and editing/Edit image",
"description": "Edits an input image via text instructions using FLUX.2 [klein] 4B."
},
{
"id": "6007e698-2ebd-4917-84d8-299b35d7b7ab",
@ -1821,7 +1822,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
}
},
"description": "Applies reference image conditioning for style/identity transfer (Flux.2 Klein 4B)."
}
]
},
@ -1837,4 +1839,4 @@
}
},
"version": 0.4
}
}

View File

@ -1417,7 +1417,8 @@
}
],
"extra": {},
"category": "Image generation and editing/Edit image"
"category": "Image generation and editing/Edit image",
"description": "Edits images via text instructions using LongCat Image Edit, an instruction-following image editing diffusion model."
}
]
},

View File

@ -132,7 +132,7 @@
},
"revision": 0,
"config": {},
"name": "local-Image Edit (Qwen 2511)",
"name": "Image Edit (Qwen 2511)",
"inputNode": {
"id": -10,
"bounding": [
@ -1468,7 +1468,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Edit image"
"category": "Image generation and editing/Edit image",
"description": "Edits images via text instructions using Qwen-Image-Edit-2511 with improved character consistency and integrated LoRA."
}
]
},
@ -1489,4 +1490,4 @@
}
},
"version": 0.4
}
}

View File

@ -1188,7 +1188,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Inpaint image"
"category": "Image generation and editing/Inpaint image",
"description": "Inpaints masked image regions using Flux.1 fill [dev], Black Forest Labs' inpainting/outpainting model."
}
]
},
@ -1202,4 +1203,4 @@
},
"ue_links": []
}
}
}

View File

@ -1548,7 +1548,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Inpaint image"
"category": "Image generation and editing/Inpaint image",
"description": "Inpaints masked regions using Qwen-Image, extending its multilingual text rendering to inpainting tasks."
},
{
"id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8",
@ -1907,7 +1908,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
}
},
"description": "Expands and softens mask edges to reduce visible seams after image processing."
}
]
},

View File

@ -742,9 +742,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust"
"category": "Image Tools/Color adjust",
"description": "Adjusts black point, white point, and gamma for tonal range control via GPU shader."
}
]
},
"extra": {}
}
}

View File

@ -1919,7 +1919,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Outpaint image"
"category": "Image generation and editing/Outpaint image",
"description": "Outpaints beyond image boundaries using Qwen-Image's outpainting capabilities."
},
{
"id": "f93c215e-c393-460e-9534-ed2c3d8a652e",
@ -2278,7 +2279,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
}
},
"description": "Expands and softens mask edges to reduce visible seams after image processing."
},
{
"id": "2a4b2cc0-db37-4302-a067-da392f38f06b",
@ -2733,7 +2735,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
}
},
"description": "Scales both image and mask together while preserving alignment for editing workflows."
}
]
},

View File

@ -141,7 +141,7 @@
},
"revision": 0,
"config": {},
"name": "local-Image Upscale(Z-image-Turbo)",
"name": "Image Upscale (Z-image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1302,7 +1302,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Enhance"
"category": "Image generation and editing/Enhance",
"description": "Upscales images to higher resolution using Z-Image-Turbo."
}
]
},

View File

@ -99,7 +99,7 @@
},
"revision": 0,
"config": {},
"name": "local-Image to Depth Map (Lotus)",
"name": "Image to Depth Map (Lotus)",
"inputNode": {
"id": -10,
"bounding": [
@ -948,7 +948,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Depth to image"
"category": "Image generation and editing/Depth to image",
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
]
},
@ -964,4 +965,4 @@
"workflowRendererVersion": "LG"
},
"version": 0.4
}
}

View File

@ -1586,7 +1586,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Image to layers"
"category": "Image generation and editing/Image to layers",
"description": "Decomposes an image into variable-resolution RGBA layers for independent editing using Qwen-Image-Layered."
}
]
},

View File

@ -72,7 +72,7 @@
},
"revision": 0,
"config": {},
"name": "local-Image to Model (Hunyuan3d 2.1)",
"name": "Image to 3D Model (Hunyuan3d 2.1)",
"inputNode": {
"id": -10,
"bounding": [
@ -765,7 +765,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "3D/Image to 3D Model"
"category": "3D/Image to 3D Model",
"description": "Generates 3D mesh models from a single input image using Hunyuan3D 2.0/2.1."
}
]
},

View File

@ -4223,7 +4223,8 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
"category": "Video generation and editing/Image to video"
"category": "Video generation and editing/Image to video",
"description": "Generates video from a single input image using LTX-2.3."
}
]
},

View File

@ -206,7 +206,7 @@
},
"revision": 0,
"config": {},
"name": "local-Image to Video (Wan 2.2)",
"name": "Image to Video (Wan 2.2)",
"inputNode": {
"id": -10,
"bounding": [
@ -2027,7 +2027,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Image to video"
"category": "Video generation and editing/Image to video",
"description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V."
}
]
},

View File

@ -134,7 +134,7 @@
},
"revision": 0,
"config": {},
"name": "local-Pose to Image (Z-Image-Turbo)",
"name": "Pose to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1298,7 +1298,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Pose to image"
"category": "Image generation and editing/Pose to image",
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
}
]
},
@ -1319,4 +1320,4 @@
}
},
"version": 0.4
}
}

View File

@ -3870,7 +3870,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Pose to video"
"category": "Video generation and editing/Pose to video",
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
}
]
},

View File

@ -270,9 +270,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Prompt enhance"
"category": "Text generation/Prompt enhance",
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
}
]
},
"extra": {}
}
}

View File

@ -302,8 +302,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Sharpen"
"category": "Image Tools/Sharpen",
"description": "Sharpens image details using a GPU fragment shader for enhanced clarity."
}
]
}
}
}

View File

@ -222,7 +222,7 @@
},
"revision": 0,
"config": {},
"name": "local-Text to Audio (ACE-Step 1.5)",
"name": "Text to Audio (ACE-Step 1.5)",
"inputNode": {
"id": -10,
"bounding": [
@ -1502,7 +1502,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Audio/Music generation"
"category": "Audio/Music generation",
"description": "Generates audio/music from text prompts using ACE-Step 1.5, a diffusion-based audio generation model."
}
]
},
@ -1518,4 +1519,4 @@
}
},
"version": 0.4
}
}

View File

@ -1029,7 +1029,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image"
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model."
}
]
},
@ -1043,4 +1044,4 @@
},
"ue_links": []
}
}
}

View File

@ -1023,7 +1023,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image"
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant."
}
]
},
@ -1037,4 +1038,4 @@
},
"ue_links": []
}
}
}

View File

@ -1104,7 +1104,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image"
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using NetaYume Lumina, fine-tuned from Neta Lumina for anime-style and illustration generation."
},
{
"id": "a07fdf06-1bda-4dac-bdbd-63ee8ebca1c9",
@ -1458,11 +1459,12 @@
],
"extra": {
"workflowRendererVersion": "LG"
}
},
"description": "Encodes a negative text prompt via CLIP for classifier-free guidance in anime-style generation (NetaYume Lumina)."
}
]
},
"extra": {
"ue_links": []
}
}
}

View File

@ -1941,7 +1941,8 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
"category": "Image generation and editing/Text to image"
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Qwen-Image-2512, with enhanced human realism and finer natural detail over the base version."
}
]
},

View File

@ -1873,7 +1873,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image"
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Qwen-Image, Alibaba's 20B MMDiT model with excellent multilingual text rendering."
}
]
},

View File

@ -149,7 +149,7 @@
},
"revision": 0,
"config": {},
"name": "local-Text to Image (Z-Image-Turbo)",
"name": "Text to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1054,7 +1054,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image"
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model."
}
]
},
@ -1075,4 +1076,4 @@
}
},
"version": 0.4
}
}

View File

@ -4286,7 +4286,8 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
"category": "Video generation and editing/Text to video"
"category": "Video generation and editing/Text to video",
"description": "Generates video from text prompts using LTX-2.3, Lightricks' video diffusion model."
}
]
},

View File

@ -1572,7 +1572,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Text to video"
"category": "Video generation and editing/Text to video",
"description": "Generates video from text prompts using Wan2.2, Alibaba's diffusion video model."
}
]
},
@ -1586,4 +1587,4 @@
"VHS_KeepIntermediate": true
},
"version": 0.4
}
}

View File

@ -434,8 +434,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Sharpen"
"category": "Image Tools/Sharpen",
"description": "Enhances edge contrast via unsharp masking for a sharper image appearance."
}
]
}
}
}

View File

@ -307,7 +307,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Video Captioning"
"category": "Text generation/Video Captioning",
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
}
]
}

View File

@ -165,7 +165,7 @@
},
"revision": 0,
"config": {},
"name": "local-Video Inpaint(Wan2.1 VACE)",
"name": "Video Inpaint (Wan 2.1 VACE)",
"inputNode": {
"id": -10,
"bounding": [
@ -2368,7 +2368,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Inpaint video"
"category": "Video generation and editing/Inpaint video",
"description": "Inpaints masked regions in video frames using Wan 2.1 VACE."
}
]
},

View File

@ -584,8 +584,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video Tools/Stitch videos"
"category": "Video Tools/Stitch videos",
"description": "Stitches multiple video clips into a single sequential video file."
}
]
}
}
}

View File

@ -412,9 +412,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Enhance video"
"category": "Video generation and editing/Enhance video",
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
}
]
},
"extra": {}
}
}

View File

@ -0,0 +1,7 @@
{
"model_type": "birefnet",
"image_std": [1.0, 1.0, 1.0],
"image_mean": [0.0, 0.0, 0.0],
"image_size": 1024,
"resize_to_original": true
}
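Reading the new birefnet config above: image_mean of 0 and image_std of 1 mean the model consumes raw [0, 1] pixels, inputs are resized to a square image_size of 1024, and resize_to_original asks for the predicted matte to be scaled back to the source resolution. A hedged pre/post-processing sketch built only from those fields; the model call and the choice of the last prediction scale are assumptions, not the actual ComfyUI node wiring.

import torch
import torch.nn.functional as F

def birefnet_matte(model, image, cfg):
    # image: [B, H, W, C] float tensor in [0, 1] (ComfyUI IMAGE layout).
    b, h, w, c = image.shape
    x = image.permute(0, 3, 1, 2)  # -> [B, C, H, W]
    size = cfg["image_size"]  # 1024 in this config
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    mean = torch.tensor(cfg["image_mean"], dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    std = torch.tensor(cfg["image_std"], dtype=x.dtype, device=x.device).view(1, -1, 1, 1)
    x = (x - mean) / std  # identity for mean 0 / std 1, kept for generality
    preds = model(x)  # assumed to return a list of scaled predictions
    mask = preds[-1].sigmoid()  # assumption: the last scale is the final matte
    if cfg.get("resize_to_original", False):
        mask = F.interpolate(mask, size=(h, w), mode="bilinear", align_corners=False)
    return mask.clamp(0, 1)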

View File

@ -0,0 +1,689 @@
import torch
import comfy.ops
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from torchvision.ops import deform_conv2d
from comfy.ldm.modules.attention import optimized_attention_for_device
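# Context channel widths concatenated onto the deepest encoder stage; the slicing
# below evaluates to [384, 768, 1536] and mirrors self.cxt computed in BiRefNet.__init__.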
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
x = optimized_attention(
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
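# Swin-style window helpers: window_partition tiles a [B, H, W, C] feature map into
# non-overlapping window_size x window_size patches, and window_reverse stitches them
# back into the original spatial layout.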
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None,
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
B, L, C = x.shape
H, W = self.H, self.W
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchMerging(nn.Module):
def __init__(self, dim, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
def forward(self, x, H, W):
B, L, C = x.shape
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
norm_layer=nn.LayerNorm,
downsample=None,
device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
device=device, dtype=dtype, operations=operations)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
else:
self.downsample = None
def forward(self, x, H, W):
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
else:
self.norm = None
def forward(self, x):
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
device=None, dtype=None, operations=None):
super().__init__()
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
device=device, dtype=dtype, operations=operations,
norm_layer=norm_layer if self.patch_norm else None)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
device=device, dtype=dtype, operations=operations)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
def forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
outs = []
x = x.flatten(2).transpose(1, 2)
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
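# Deformable convolution wrapper: learns per-pixel offsets and a sigmoid modulation mask,
# then applies torchvision's deform_conv2d with the regular conv's weights cast via comfy.ops.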
class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False, device=None, dtype=None, operations=None):
super(DeformableConv2d, self).__init__()
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) is tuple else (stride, stride)
self.padding = padding
self.offset_conv = operations.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.modulator_conv = operations.Conv2d(in_channels,
1 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias, device=device, dtype=dtype)
def forward(self, x):
offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d(
input=x,
offset=offset,
weight=weight,
bias=None,
padding=self.padding,
mask=modulator,
stride=self.stride,
)
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x
class BasicDecBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
super(BasicDecBlk, self).__init__()
inter_channels = 64
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.relu_in = nn.ReLU(inplace=True)
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
def forward(self, x):
x = self.conv_in(x)
x = self.bn_in(x)
x = self.relu_in(x)
x = self.dec_att(x)
x = self.conv_out(x)
x = self.bn_out(x)
return x
class BasicLatBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
super(BasicLatBlk, self).__init__()
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return x
class _ASPPModuleDeformable(nn.Module):
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
super(_ASPPModuleDeformable, self).__init__()
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
class ASPPDeformable(nn.Module):
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
super(ASPPDeformable, self).__init__()
self.down_scale = 1
if out_channels is None:
out_channels = in_channels
self.in_channelster = 256 // self.down_scale
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
self.aspp_deforms = nn.ModuleList([
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
for conv_size in parallel_block_sizes
])
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
nn.ReLU(inplace=True))
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x1 = self.aspp1(x)
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
x5 = self.global_avg_pool(x)
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return x
class BiRefNet(nn.Module):
def __init__(self, config=None, dtype=None, device=None, operations=None):
super(BiRefNet, self).__init__()
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
channels = [1536, 768, 384, 192]
channels = [c * 2 for c in channels]
self.cxt = channels[1:][::-1][-3:]
self.squeeze_module = nn.Sequential(*[
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
for _ in range(1)
])
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
def forward_enc(self, x):
x1, x2, x3, x4 = self.bb(x)
B, C, H, W = x.shape
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat(
(
*[
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
][-len(CXT):],
x4
),
dim=1
)
return (x1, x2, x3, x4)
def forward_ori(self, x):
(x1, x2, x3, x4) = self.forward_enc(x)
x4 = self.squeeze_module(x4)
features = [x, x1, x2, x3, x4]
scaled_preds = self.decoder(features)
return scaled_preds
def forward(self, pixel_values, intermediate_output=None):
scaled_preds = self.forward_ori(pixel_values)
return scaled_preds
class Decoder(nn.Module):
def __init__(self, channels, device, dtype, operations):
super(Decoder, self).__init__()
# factory kwargs
fk = {"device":device, "dtype":dtype, "operations":operations}
DecoderBlock = partial(BasicDecBlk, **fk)
LateralBlock = partial(BasicLatBlk, **fk)
DBlock = partial(SimpleConvs, **fk)
self.split = True
N_dec_ipt = 64
ic = 64
ipt_cha_opt = 1
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
fk = {"device":device, "dtype":dtype}
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
self.lateral_block4 = LateralBlock(channels[1], channels[1])
self.lateral_block3 = LateralBlock(channels[2], channels[2])
self.lateral_block2 = LateralBlock(channels[3], channels[3])
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
_N = 16
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
def get_patches_batch(self, x, p):
_size_h, _size_w = p.shape[2:]
patches_batch = []
for idx in range(x.shape[0]):
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
patches_x = []
for column_x in columns_x:
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
patch_sample = torch.cat(patches_x, dim=1)
patches_batch.append(patch_sample)
return torch.cat(patches_batch, dim=0)
def forward(self, features):
x, x1, x2, x3, x4 = features
patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
p4_gdt = self.gdt_convs_4(p4)
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
p4 = p4 * gdt_attn_4
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
_p3 = _p4 + self.lateral_block4(x3)
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
p3_gdt = self.gdt_convs_3(p3)
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
p3 = p3 * gdt_attn_3
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
_p2 = _p3 + self.lateral_block3(x2)
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
p2_gdt = self.gdt_convs_2(p2)
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
p2 = p2 * gdt_attn_2
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
_p1 = _p2 + self.lateral_block2(x1)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
_p1 = self.decoder_block1(_p1)
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)
return p1_out
class SimpleConvs(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
def forward(self, x):
return self.conv_out(self.conv1(x))
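The decoder's get_patches_batch above tiles the full-resolution input into patches the size of a given feature map and stacks the tiles along the channel axis, which is why ipt_blk5 expects 2**10*3 input channels for a 1024px image reduced to a 32x32 grid. A minimal standalone sketch of that reshaping (illustrative only, not imported from this file):

import torch

def patches_to_channels(x, target_hw):
    # x: (B, C, H, W); target_hw: spatial size of the feature map the patches should match
    size_h, size_w = target_hw
    out = []
    for img in x:                                  # img: (C, H, W)
        cols = torch.split(img, size_w, dim=-1)    # split along width
        tiles = [t.unsqueeze(0) for col in cols
                 for t in torch.split(col, size_h, dim=-2)]
        out.append(torch.cat(tiles, dim=1))        # (1, C * n_tiles, size_h, size_w)
    return torch.cat(out, dim=0)

x = torch.randn(1, 3, 1024, 1024)                  # preprocessed input image
print(patches_to_channels(x, (32, 32)).shape)      # torch.Size([1, 3072, 32, 32]) == 2**10 * 3 channels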

comfy/bg_removal_model.py Normal file
View File

@ -0,0 +1,78 @@
from .utils import load_torch_file
import os
import json
import torch
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.clip_model
import comfy.background_removal.birefnet
BG_REMOVAL_MODELS = {
"birefnet": comfy.background_removal.birefnet.BiRefNet
}
class BackgroundRemovalModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 1024)
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
self.model_type = config.get("model_type", "birefnet")
self.config = config.copy()
model_class = BG_REMOVAL_MODELS.get(self.model_type)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
def load_background_removal_model(sd):
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
else:
return None
bg_model = BackgroundRemovalModel(json_config)
m, u = bg_model.load_sd(sd)
if len(m) > 0:
logging.warning("missing background removal: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
return bg_model
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
return load_background_removal_model(sd)
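A minimal usage sketch of the loader above; the checkpoint path and filename are illustrative assumptions, and the image follows ComfyUI's IMAGE convention of (batch, height, width, channels) in [0, 1]:

import torch
from comfy.bg_removal_model import load

bg_model = load("models/background_removal/birefnet.safetensors")   # hypothetical path
if bg_model is None:
    raise RuntimeError("checkpoint is not a supported background removal model")

image = torch.rand(1, 512, 768, 3)
mask = bg_model.encode_image(image)   # soft foreground mask, shape (1, 1, 512, 768)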

View File

@ -93,7 +93,7 @@ class Hook:
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):

View File

@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')

View File

@ -1135,7 +1135,7 @@ class AudioInjector_WAN(nn.Module):
self.injector_adain_output_layers = nn.ModuleList(
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len, scale=1.0):
audio_attn_id = self.injected_block_id.get(block_id, None)
if audio_attn_id is None:
return x
@ -1148,12 +1148,15 @@ class AudioInjector_WAN(nn.Module):
attn_hidden_states = adain_hidden_states
else:
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
attn_audio_emb = audio_emb
if audio_emb.dim() == 3: # WanDancer case
attn_audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
else: # S2V case
attn_audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
residual_out = rearrange(
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
x[:, :seq_len] = x[:, :seq_len] + residual_out
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
x[:, :seq_len] = x[:, :seq_len] + residual_out * scale
return x
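The branch above only changes how the audio embedding is flattened for cross-attention: WanDancer provides one vector per frame (b, t, c), while S2V provides n tokens per frame (b, t, n, c). A small shape check with made-up sizes (b=1, t=4, n=8, c=64), not taken from the model code:

import torch
from einops import rearrange

num_frames = 4
wandancer_emb = torch.randn(1, num_frames, 64)        # (b, t, c)
s2v_emb = torch.randn(1, num_frames, 8, 64)           # (b, t, n, c)
print(rearrange(wandancer_emb, "b t c -> (b t) 1 c", t=num_frames).shape)   # torch.Size([4, 1, 64])
print(rearrange(s2v_emb, "b t n c -> (b t) n c", t=num_frames).shape)       # torch.Size([4, 8, 64])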

View File

@ -0,0 +1,251 @@
import torch
import torch.nn as nn
import comfy
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
from .model import AudioInjector_WAN, WanModel, MLPProj, Head, sinusoidal_embedding_1d
class MusicSelfAttention(nn.Module):
def __init__(self, dim, num_heads, device=None, dtype=None, operations=None):
assert dim % num_heads == 0
super().__init__()
self.embed_dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = self.q_proj(x).view(b, s, n, d)
q = apply_rope1(q, freqs)
k = self.k_proj(x).view(b, s, n, d)
k = apply_rope1(k, freqs)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
self.v_proj(x).view(b, s, n * d),
heads=self.num_heads,
)
return self.out_proj(x)
class MusicEncoderLayer(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = MusicSelfAttention(dim, num_heads, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(dim, ffn_dim, device=device, dtype=dtype)
self.linear2 = operations.Linear(ffn_dim, dim, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
x = x + self.self_attn(self.norm1(x), freqs=freqs)
x = x + self.linear2(torch.nn.functional.gelu(self.linear1(self.norm2(x)))) # ffn
return x
class WanDancerModel(WanModel):
def __init__(self,
model_type='wandancer',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=5120,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=40,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
in_dim_ref_conv=None,
image_model=None,
device=None, dtype=None, operations=None,
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
music_dim = 256,
music_heads = 4,
music_feature_dim = 35,
music_latent_dim = 256
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim,
num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, in_dim_ref_conv=in_dim_ref_conv,
device=device, dtype=dtype, operations=operations)
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.patch_embedding_global = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
self.img_emb_refimage = MLPProj(1280, dim, operation_settings=operation_settings)
self.head_global = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
self.music_injector = AudioInjector_WAN(
dim=self.dim,
num_heads=self.num_heads,
inject_layer=audio_inject_layers,
root_net=self,
enable_adain=False,
dtype=dtype, device=device, operations=operations
)
self.music_projection = operations.Linear(music_feature_dim, music_latent_dim, device=device, dtype=dtype)
self.music_encoder = nn.ModuleList([MusicEncoderLayer(dim=music_dim, num_heads=music_heads, ffn_dim=1024, device=device, dtype=dtype, operations=operations) for _ in range(2)])
music_head_dim = music_dim // music_heads
self.music_rope_embedder = EmbedND(dim=music_head_dim, theta=10000.0, axes_dim=[music_head_dim])
def forward_orig(self, x, t, context, clip_fea=None, clip_fea_ref=None, freqs=None, audio_embed=None, fps=30, audio_inject_scale=1.0, transformer_options={}, **kwargs):
# embeddings
if int(fps + 0.5) != 30:
x = self.patch_embedding_global(x.float()).to(x.dtype)
else:
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
latent_frames = grid_sizes[0]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
seq_len = x.size(1)
# time embeddings
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
full_ref = None
if self.ref_conv is not None: # model has the weight, but this wasn't used in the original pipeline
full_ref = kwargs.get("reference_latent", None)
if full_ref is not None:
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)
# context
context = self.text_embedding(context)
audio_emb = None
if audio_embed is not None: # encode music feature[1, frame_num, 35] -> [1, F*8, dim]
music_feature = self.music_projection(audio_embed)
music_seq_len = music_feature.shape[1]
music_ids = torch.arange(music_seq_len, device=music_feature.device, dtype=music_feature.dtype).reshape(1, -1, 1) # create 1D position IDs
music_freqs = self.music_rope_embedder(music_ids).movedim(1, 2)
# apply encoder layers
for layer in self.music_encoder:
music_feature = layer(music_feature, music_freqs)
# interpolate
audio_emb = torch.nn.functional.interpolate(music_feature.unsqueeze(1), size=(latent_frames * 8, self.dim), mode='bilinear').squeeze(1)
context_img_len = 0
if self.img_emb is not None and clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.cat([context_clip, context], dim=1)
context_img_len += clip_fea.shape[-2]
if self.img_emb_refimage is not None and clip_fea_ref is not None:
context_clip_ref = self.img_emb_refimage(clip_fea_ref)
context = torch.cat([context_clip_ref, context], dim=1)
context_img_len += clip_fea_ref.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
if audio_emb is not None:
x = self.music_injector(x, i, audio_emb, audio_emb_global=None, seq_len=seq_len, scale=audio_inject_scale)
# head
if int(fps + 0.5) != 30:
x = self.head_global(x, e)
else:
x = self.head(x, e)
if full_ref is not None:
x = x[:, full_ref.shape[1]:]
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, clip_fea_ref=None, fps=30, audio_inject_scale=1.0, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = x.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, fps=fps, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, clip_fea_ref=clip_fea_ref, freqs=freqs, fps=fps, audio_inject_scale=audio_inject_scale, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, fps=30, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
if int(fps + 0.5) != 30:
time_scale = 30.0 / fps # how many time units each frame represents relative to 30fps
positions_new = torch.arange(steps_t, device=device, dtype=dtype) * time_scale + t_start
total_frames_at_30fps = int(time_scale * steps_t + 0.5)
positions_new[-1] = t_start + (total_frames_at_30fps - 1)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + positions_new.reshape(-1, 1, 1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return freqs
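Worked example of the fps-dependent time positions in rope_encode above (plain arithmetic, not model code): at fps=16 each latent frame spans 30/16 = 1.875 time units on the 30 fps grid, and the last position is overridden to where the final 30 fps frame would sit.

import torch

fps, steps_t, t_start = 16.0, 4, 0
time_scale = 30.0 / fps                                    # 1.875
positions = torch.arange(steps_t, dtype=torch.float32) * time_scale + t_start
total_frames_at_30fps = int(time_scale * steps_t + 0.5)    # 8
positions[-1] = t_start + (total_frames_at_30fps - 1)      # 7.0
print(positions)                                           # tensor([0.0000, 1.8750, 3.7500, 7.0000])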

View File

@ -43,6 +43,7 @@ import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.wan.model_wandancer
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
@ -1599,6 +1600,30 @@ class WAN21_SCAIL(WAN21):
return out
class WAN22_WanDancer(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
clip_vision_output_ref = kwargs.get("clip_vision_output_ref", None)
if clip_vision_output_ref is not None:
out['clip_fea_ref'] = comfy.conds.CONDRegular(clip_vision_output_ref.penultimate_hidden_states)
fps = kwargs.get("fps", None)
if fps is not None:
out['fps'] = comfy.conds.CONDRegular(torch.FloatTensor([fps]))
audio_inject_scale = kwargs.get("audio_inject_scale", None)
if audio_inject_scale is not None:
out['audio_inject_scale'] = comfy.conds.CONDRegular(torch.FloatTensor([audio_inject_scale]))
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -572,6 +572,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "animate"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "wandancer"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@ -562,6 +562,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True

View File

@ -1313,6 +1313,37 @@ class WAN21_SCAIL(WAN21_T2V):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out
class WAN22_WanDancer(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "wandancer",
"in_dim": 36,
}
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 1.8
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22_WanDancer(self, image_to_video=True, device=device)
return out
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
# split music_encoder in_proj into q_proj, k_proj, v_proj
if "music_encoder" in k and "self_attn.in_proj" in k:
suffix = "weight" if k.endswith("weight") else "bias"
tensor = state_dict[k]
d = tensor.shape[0] // 3
prefix = k.replace(f"in_proj_{suffix}", "")
out_sd[f"{prefix}q_proj.{suffix}"] = tensor[:d]
out_sd[f"{prefix}k_proj.{suffix}"] = tensor[d:2*d]
out_sd[f"{prefix}v_proj.{suffix}"] = tensor[2*d:]
else:
out_sd[k] = state_dict[k]
return out_sd
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@ -1982,6 +2013,7 @@ models = [
WAN22_Animate,
WAN21_FlowRVS,
WAN21_SCAIL,
WAN22_WanDancer,
Hunyuan3Dv2mini,
Hunyuan3Dv2,
Hunyuan3Dv2_1,

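The process_unet_state_dict hook above splits the checkpoint's fused music_encoder in_proj matrix into separate q/k/v projections. A toy illustration of the slicing (names follow the code, the tensor is made up):

import torch

dim = 6
in_proj_weight = torch.randn(3 * dim, dim)   # fused q/k/v projection as stored in the checkpoint
d = in_proj_weight.shape[0] // 3
q_w, k_w, v_w = in_proj_weight[:d], in_proj_weight[d:2*d], in_proj_weight[2*d:]
assert q_w.shape == k_w.shape == v_w.shape == (dim, dim)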
View File

@ -1390,7 +1390,7 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
layer_conf = {"format": "float8_e4m3fn"}
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from spandrel import ImageModelDescriptor
from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_
from comfy.bg_removal_model import BackgroundRemovalModel
from comfy.controlnet import ControlNet
from comfy.hooks import HookGroup, HookKeyframeGroup
from comfy.model_patcher import ModelPatcher
@ -614,6 +615,11 @@ class Model(ComfyTypeIO):
if TYPE_CHECKING:
Type = ModelPatcher
@comfytype(io_type="BACKGROUND_REMOVAL")
class BackgroundRemoval(ComfyTypeIO):
if TYPE_CHECKING:
Type = BackgroundRemovalModel
@comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO):
if TYPE_CHECKING:
@ -2257,6 +2263,7 @@ __all__ = [
"ModelPatch",
"ClipVision",
"ClipVisionOutput",
"BackgroundRemoval",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",

View File

@ -1,10 +1,11 @@
from __future__ import annotations
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from typing import Optional, Any
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_1_20260211 = 'v3.1-20260211'
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
@ -142,7 +143,7 @@ class TripoFileEmptyReference(BaseModel):
pass
class TripoFileReference(RootModel):
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference
class TripoGetStsTokenRequest(BaseModel):
format: str = Field(..., description='The format of the image')
@ -183,7 +184,7 @@ class TripoImageToModelRequest(BaseModel):
class TripoMultiviewToModelRequest(BaseModel):
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
files: list[TripoFileReference] = Field(..., description='The file references to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
@ -251,27 +252,13 @@ class TripoConvertModelRequest(BaseModel):
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
part_names: Optional[list[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
class TripoTaskRequest(RootModel):
root: Union[
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimatePrerigcheckRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoStylizeModelRequest,
TripoConvertModelRequest
]
class TripoTaskOutput(BaseModel):
model: Optional[str] = Field(None, description='URL to the model')
base_model: Optional[str] = Field(None, description='URL to the base model')
@ -283,12 +270,13 @@ class TripoTask(BaseModel):
task_id: str = Field(..., description='The task ID')
type: Optional[str] = Field(None, description='The type of task')
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task')
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
create_time: Optional[int] = Field(None, description='The creation time of the task')
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
queue_position: Optional[int] = Field(None, description='The position in the queue')
consumed_credit: int | None = Field(None)
class TripoTaskResponse(BaseModel):
code: int = Field(0, description='The response code')
@ -296,7 +284,7 @@ class TripoTaskResponse(BaseModel):
class TripoGeneralResponse(BaseModel):
code: int = Field(0, description='The response code')
data: Dict[str, str] = Field(..., description='The task ID data')
data: dict[str, str] = Field(..., description='The task ID data')
class TripoBalanceData(BaseModel):
balance: float = Field(..., description='The account balance')

View File

@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
)
def _seedance2_text_inputs(resolutions: list[str]):
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [
IO.String.Input(
"prompt",
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
IO.Combo.Input(
"ratio",
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
default=default_ratio,
tooltip="Aspect ratio of the output video.",
),
IO.Int.Input(
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def _seedance2_reference_inputs(resolutions: list[str]):
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [
*_seedance2_text_inputs(resolutions),
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),

View File

@ -60,6 +60,7 @@ async def poll_until_finished(
],
status_extractor=lambda x: x.data.status,
progress_extractor=lambda x: x.data.progress,
price_extractor=lambda x: x.data.consumed_credit * 0.01 if x.data.consumed_credit else None,
estimated_duration=average_duration,
)
if response_poll.data.status == TripoTaskStatus.SUCCESS:
@ -113,7 +114,6 @@ class TripoTextToModelNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
@ -124,20 +124,17 @@ class TripoTextToModelNode(IO.ComfyNode):
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$isV3OrLater := $contains(widgets.model_version,"v3.");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 20 : ($withTexture ? 20 : 10);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
$credits := $isV14 ? 20 : (
($withTexture ? 20 : 10)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
);
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
)
""",
),
@ -239,7 +236,6 @@ class TripoImageToModelNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
@ -250,20 +246,17 @@ class TripoImageToModelNode(IO.ComfyNode):
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$isV3OrLater := $contains(widgets.model_version,"v3.");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
$credits := $isV14 ? 30 : (
($withTexture ? 30 : 20)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
);
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
)
""",
),
@ -358,7 +351,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
],
outputs=[
@ -379,7 +372,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
"model_version",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
@ -387,17 +379,16 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$isV3OrLater := $contains(widgets.model_version,"v3.");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ (widgets.quad ? 5 : 0)
$credits := $isV14 ? 30 : (
($withTexture ? 30 : 20)
+ ($isHdTexture ? 10 : 0)
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
);
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
)
""",
),
@ -457,7 +448,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
face_limit=face_limit if face_limit != -1 else None,
quad=quad,
quad=None,
),
)
return await poll_until_finished(cls, response, average_duration=80)
@ -498,7 +489,7 @@ class TripoTextureNode(IO.ComfyNode):
expr="""
(
$tq := widgets.texture_quality;
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)}
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1), "format": {"approximate": true}}
)
""",
),
@ -555,7 +546,7 @@ class TripoRefineNode(IO.ComfyNode):
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.3}""",
expr="""{"type":"usd","usd":0.3, "format": {"approximate": true}}""",
),
)
@ -592,7 +583,7 @@ class TripoRigNode(IO.ComfyNode):
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25}""",
expr="""{"type":"usd","usd":0.25, "format": {"approximate": true}}""",
),
)
@ -652,7 +643,7 @@ class TripoRetargetNode(IO.ComfyNode):
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1}""",
expr="""{"type":"usd","usd":0.1, "format": {"approximate": true}}""",
),
)
@ -761,19 +752,10 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit",
"texture_size",
"texture_format",
"force_symmetry",
"flatten_bottom",
"flatten_bottom_threshold",
"pivot_to_center_bottom",
"scale_factor",
"with_animation",
"pack_uv",
"bake",
"part_names",
"fbx_preset",
"export_vertex_colors",
"export_orientation",
"animate_in_place",
],
),
expr="""
@ -783,28 +765,16 @@ class TripoConversionNode(IO.ComfyNode):
$flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0;
$scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1;
$texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg");
$part := widgets.part_names;
$fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender");
$orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default");
$advanced :=
widgets.quad or
widgets.force_symmetry or
widgets.flatten_bottom or
widgets.pivot_to_center_bottom or
widgets.with_animation or
widgets.pack_uv or
widgets.bake or
widgets.export_vertex_colors or
widgets.animate_in_place or
($face != -1) or
($texSize != 4096) or
($flatThresh != 0) or
($scale != 1) or
($texFmt != "jpeg") or
($part != "") or
($fbx != "blender") or
($orient != "default");
{"type":"usd","usd": ($advanced ? 0.1 : 0.05)}
($texFmt != "jpeg");
{"type":"usd","usd": ($advanced ? 0.1 : 0.05), "format": {"approximate": true}}
)
""",
),

View File

@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]:
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
# Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required
# to correctly detect that Chinese users with working internet do have working internet.
internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
async with aiohttp.ClientSession(timeout=timeout) as session:
with contextlib.suppress(ClientError, OSError):
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
async def _probe(url: str) -> bool:
try:
async with session.get(url) as resp:
return resp.status < 500
except (ClientError, OSError, asyncio.TimeoutError):
return False
probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
try:
for fut in asyncio.as_completed(probe_tasks):
if await fut:
results["internet_accessible"] = True
break
finally:
for t in probe_tasks:
if not t.done():
t.cancel()
await asyncio.gather(*probe_tasks, return_exceptions=True)
if not results["internet_accessible"]:
return results

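The probe logic above starts both requests, returns as soon as one succeeds, and cancels the rest so a blocked endpoint cannot delay the result. A standalone sketch of the same first-success-wins pattern with simulated probes (no aiohttp):

import asyncio

async def probe(delay: float, ok: bool) -> bool:
    await asyncio.sleep(delay)   # stand-in for an HTTP GET
    return ok

async def main() -> bool:
    tasks = [asyncio.create_task(probe(5.0, False)),   # slow or blocked endpoint
             asyncio.create_task(probe(0.1, True))]    # reachable endpoint
    reachable = False
    try:
        for fut in asyncio.as_completed(tasks):
            if await fut:
                reachable = True
                break
    finally:
        for t in tasks:
            if not t.done():
                t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    return reachable

print(asyncio.run(main()))   # True after ~0.1s, without waiting for the slow probe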
View File

@ -0,0 +1,60 @@
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy.bg_removal_model import load
class LoadBackgroundRemovalModel(IO.ComfyNode):
@classmethod
def define_schema(cls):
files = folder_paths.get_filename_list("background_removal")
return IO.Schema(
node_id="LoadBackgroundRemovalModel",
display_name="Load Background Removal Model",
category="loaders",
inputs=[
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
],
outputs=[
IO.BackgroundRemoval.Output("bg_model")
]
)
@classmethod
def execute(cls, bg_removal_name):
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
bg = load(path)
if bg is None:
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
return IO.NodeOutput(bg)
class RemoveBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RemoveBackground",
display_name="Remove Background",
category="image/background removal",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)
class BackgroundRemovalExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LoadBackgroundRemovalModel,
RemoveBackground
]
async def comfy_entrypoint() -> BackgroundRemovalExtension:
return BackgroundRemovalExtension()

View File

@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))

View File

@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [
PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return io.NodeOutput(image)

View File

@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
if bypass:
return (latent,)
samples = latent["samples"]
samples = latent["samples"].clone()
_, height_scale_factor, width_scale_factor = (
vae.downscale_index_formula
)
batch, _, latent_frames, latent_height, latent_width = samples.shape
_, _, _, latent_height, latent_width = samples.shape
width = latent_width * width_scale_factor
height = latent_height * height_scale_factor
@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
samples[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = torch.ones(
(batch, 1, latent_frames, 1, 1),
dtype=torch.float32,
device=samples.device,
)
conditioning_latent_frames_mask = get_noise_mask(latent)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
@ -236,7 +232,7 @@ class LTXVAddGuide(io.ComfyNode):
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t

View File

@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
source_rgb = source[:, :3, :visible_height, :visible_width]
dest_slice = destination[..., top:bottom, left:right]
if destination.shape[1] == 4:
if torch.max(dest_slice) == 0:
destination[:, :3, top:bottom, left:right] = source_rgb
destination[:, 3:4, top:bottom, left:right] = mask
else:
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
else:
source_portion = mask * source_rgb
destination_portion = inverse_mask * dest_slice
destination[..., top:bottom, left:right] = source_portion + destination_portion
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
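A standalone sketch of the new RGBA branch above (illustrative, not the node itself): the RGB channels are alpha-blended with the paste mask, while the destination's alpha keeps the maximum coverage of the two layers.

import torch

mask = torch.rand(1, 1, 64, 64)     # soft paste mask
src_rgb = torch.rand(1, 3, 64, 64)
dst = torch.rand(1, 4, 64, 64)      # RGBA destination

out = dst.clone()
out[:, :3] = mask * src_rgb + (1.0 - mask) * dst[:, :3]
out[:, 3:4] = torch.max(mask, dst[:, 3:4])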
class LatentCompositeMasked(IO.ComfyNode):
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
display_name="Image Composite Masked",
category="image",
inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False),
IO.Image.Input("destination", optional=True),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
if destination is None: # transparent rgba
B, H, W, C = source.shape
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
if C == 3:
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode):
@classmethod
def define_schema(cls):

View File

@ -63,7 +63,7 @@ class MathExpressionNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
autogrow = io.Autogrow.TemplateNames(
input=io.MultiType.Input("value", [io.Float, io.Int]),
input=io.MultiType.Input("value", [io.Float, io.Int, io.Boolean]),
names=list(string.ascii_lowercase),
min=1,
)
@ -82,6 +82,7 @@ class MathExpressionNode(io.ComfyNode):
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
io.Boolean.Output(display_name="BOOL"),
],
)
@ -97,7 +98,7 @@ class MathExpressionNode(io.ComfyNode):
result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
# bool check must come first because bool is a subclass of int in Python
if isinstance(result, bool) or not isinstance(result, (int, float)):
if not isinstance(result, (int, float)):
raise ValueError(
f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}"
@ -106,7 +107,7 @@ class MathExpressionNode(io.ComfyNode):
raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}"
)
return io.NodeOutput(float(result), int(result))
return io.NodeOutput(float(result), int(result), bool(result))
class MathExtension(ComfyExtension):

View File

@ -0,0 +1,971 @@
import math
import nodes
import node_helpers
import torch
import torchaudio
import comfy.model_management
import comfy.utils
import numpy as np
import logging
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import scipy.signal
import scipy.ndimage
import scipy.fft
import scipy.sparse
# Audio Processing Functions - Derived from librosa (https://github.com/librosa/librosa)
# Copyright (c) 2013--2023, librosa development team.
def mel_to_hz(mels, htk=False):
"""Convert mel to Hz (slaney)"""
mels = np.asanyarray(mels)
if htk:
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = np.log(6.4) / 27.0
if mels.ndim:
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
elif mels >= min_log_mel:
freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
return freqs
def hz_to_mel(frequencies, htk=False):
"""Convert Hz to mel (slaney)"""
frequencies = np.asanyarray(frequencies)
if htk:
return 2595.0 * np.log10(1.0 + frequencies / 700.0)
f_min = 0.0
f_sp = 200.0 / 3
mels = (frequencies - f_min) / f_sp
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = np.log(6.4) / 27.0
if frequencies.ndim:
log_t = frequencies >= min_log_hz
mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
elif frequencies >= min_log_hz:
mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
return mels
def compute_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, bins_per_octave=12, tuning=0.0):
"""Compute Constant-Q Transform (CQT) spectrogram."""
def _relative_bandwidth(freqs):
bpo = np.empty_like(freqs)
logf = np.log2(freqs)
bpo[0] = 1.0 / (logf[1] - logf[0])
bpo[-1] = 1.0 / (logf[-1] - logf[-2])
bpo[1:-1] = 2.0 / (logf[2:] - logf[:-2])
return (2.0 ** (2.0 / bpo) - 1.0) / (2.0 ** (2.0 / bpo) + 1.0)
def _wavelet_lengths(freqs, sr, filter_scale, alpha):
Q = float(filter_scale) / alpha
return Q * sr / freqs # shape (n_bins,) floats
def _build_wavelet(freqs_oct, sr, filter_scale, alpha_oct):
lengths = _wavelet_lengths(freqs_oct, sr, filter_scale, alpha_oct)
filters = []
for ilen, freq in zip(lengths, freqs_oct):
t = np.arange(int(-ilen // 2), int(ilen // 2), dtype=float)
sig = (np.cos(t * 2 * np.pi * freq / sr)
+ 1j * np.sin(t * 2 * np.pi * freq / sr)).astype(np.complex64)
sig *= scipy.signal.get_window('hann', len(sig), fftbins=True)
l1 = np.sum(np.abs(sig))
tiny = np.finfo(np.float32).tiny
sig /= max(l1, tiny)
filters.append(sig)
max_len = max(lengths)
n_fft = int(2.0 ** np.ceil(np.log2(max_len)))
out = np.zeros((len(filters), n_fft), dtype=np.complex64)
for k, f in enumerate(filters):
lpad = int((n_fft - len(f)) // 2)
out[k, lpad: lpad + len(f)] = f
return out, lengths
def _resample_half(y):
ratio = 0.5
n_samples = int(np.ceil(len(y) * ratio))
# Kaiser-windowed FIR matches librosa/soxr more closely than scipy's default Hamming filter
L = 2
h = scipy.signal.firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
y_hat = scipy.signal.resample_poly(y.astype(np.float32), 1, 2, window=h)
if len(y_hat) > n_samples:
y_hat = y_hat[:n_samples]
elif len(y_hat) < n_samples:
y_hat = np.pad(y_hat, (0, n_samples - len(y_hat)))
y_hat /= np.sqrt(ratio)
return y_hat.astype(np.float32)
def _sparsify_rows(x, quantile=0.01):
mags = np.abs(x)
norms = np.sum(mags, axis=1, keepdims=True)
norms = np.where(norms == 0, 1.0, norms)
mag_sort = np.sort(mags, axis=1)
cumulative_mag = np.cumsum(mag_sort / norms, axis=1)
threshold_idx = np.argmin(cumulative_mag < quantile, axis=1)
x_sparse = scipy.sparse.lil_matrix(x.shape, dtype=x.dtype)
for i, j in enumerate(threshold_idx):
idx = np.where(mags[i] >= mag_sort[i, j])
x_sparse[i, idx] = x[i, idx]
return x_sparse.tocsr()
if fmin is None:
fmin = 32.70319566257483 # C1 note frequency
fmin = fmin * (2.0 ** (tuning / bins_per_octave))
freqs = fmin * (2.0 ** (np.arange(n_bins) / bins_per_octave))
alpha = _relative_bandwidth(freqs)
lengths = _wavelet_lengths(freqs, float(sr), 1, alpha)
n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
n_filters = min(bins_per_octave, n_bins)
cqt_resp = []
my_y = y.astype(np.float32)
my_sr = float(sr)
my_hop = int(hop_length)
for i in range(n_octaves):
if i == 0:
sl = slice(-n_filters, None)
else:
sl = slice(-n_filters * (i + 1), -n_filters * i)
freqs_oct = freqs[sl]
alpha_oct = alpha[sl]
basis, basis_lengths = _build_wavelet(freqs_oct, my_sr, 1, alpha_oct)
n_fft_oct = basis.shape[1]
# Frequency-domain normalisation
basis = basis.astype(np.complex64)
basis *= basis_lengths[:, np.newaxis] / float(n_fft_oct)
fft_basis = scipy.fft.fft(basis, n=n_fft_oct, axis=1)[:, :(n_fft_oct // 2) + 1]
fft_basis = _sparsify_rows(fft_basis, quantile=0.01)
fft_basis = fft_basis * np.sqrt(sr / my_sr)
y_pad = np.pad(my_y, int(n_fft_oct // 2), mode='constant')
n_frames = 1 + (len(y_pad) - n_fft_oct) // my_hop
frames = np.lib.stride_tricks.as_strided(
y_pad,
shape=(n_fft_oct, n_frames),
strides=(y_pad.strides[0], y_pad.strides[0] * my_hop),
)
stft_result = scipy.fft.rfft(frames, axis=0)
cqt_resp.append(fft_basis.dot(stft_result))
if my_hop % 2 == 0:
my_hop //= 2
my_sr /= 2.0
my_y = _resample_half(my_y)
max_col = min(c.shape[-1] for c in cqt_resp)
cqt_out = np.empty((n_bins, max_col), dtype=np.complex64)
end = n_bins
for c_i in cqt_resp:
n_oct = c_i.shape[0]
if end < n_oct:
cqt_out[:end, :] = c_i[-end:, :max_col]
else:
cqt_out[end - n_oct:end, :] = c_i[:, :max_col]
end -= n_oct
cqt_out /= np.sqrt(lengths)[:, np.newaxis]
return np.abs(cqt_out).astype(np.float32)
def cq_to_chroma_mapping(n_input, bins_per_octave=12, n_chroma=12, fmin=None):
"""Map CQT bins to chroma bins."""
if fmin is None:
fmin = 32.70319566257483 # C1 note frequency
n_merge = bins_per_octave / n_chroma
cq_to_ch = np.repeat(np.eye(n_chroma), int(n_merge), axis=1)
cq_to_ch = np.roll(cq_to_ch, -int(n_merge // 2), axis=1)
n_octaves = int(np.ceil(n_input / bins_per_octave))
cq_to_ch = np.tile(cq_to_ch, n_octaves)[:, :n_input]
midi_0 = np.mod(12 * np.log2(fmin / 440.0) + 69, 12)
roll = int(np.round(midi_0 * (n_chroma / 12.0)))
cq_to_ch = np.roll(cq_to_ch, roll, axis=0)
return cq_to_ch.astype(np.float32)
def _parabolic_interpolation(S, axis=-2):
"""Compute parabolic interpolation shift for peak refinement."""
S_next = np.roll(S, -1, axis=axis)
S_prev = np.roll(S, 1, axis=axis)
a = S_next + S_prev - 2 * S
b = (S_next - S_prev) / 2.0
shifts = np.zeros_like(S)
valid = np.abs(b) < np.abs(a)
shifts[valid] = -b[valid] / a[valid]
if axis == -2 or axis == S.ndim - 2:
shifts[0, :] = 0
shifts[-1, :] = 0
elif axis == 0:
shifts[0, ...] = 0
shifts[-1, ...] = 0
return shifts
def _localmax(S, axis=-2):
"""Find local maxima along an axis."""
S_prev = np.roll(S, 1, axis=axis)
S_next = np.roll(S, -1, axis=axis)
local_max = (S > S_prev) & (S >= S_next)
if axis == -2 or axis == S.ndim - 2:
local_max[-1, :] = S[-1, :] > S[-2, :]
# First element is never a local max (strict inequality with previous)
local_max[0, :] = False
elif axis == 0:
local_max[-1, ...] = S[-1, ...] > S[-2, ...]
local_max[0, ...] = False
return local_max
def piptrack(y=None, sr=22050, S=None, n_fft=2048, hop_length=512,
fmin=150.0, fmax=4000.0, threshold=0.1):
"""Pitch tracking on thresholded parabolically-interpolated STFT."""
# Compute STFT if not provided
if S is None:
if y is None:
raise ValueError("Either y or S must be provided")
fft_window = scipy.signal.get_window('hann', n_fft, fftbins=True)
if len(fft_window) < n_fft:
lpad = int((n_fft - len(fft_window)) // 2)
fft_window = np.pad(fft_window, (lpad, int(n_fft - len(fft_window) - lpad)), mode='constant')
fft_window = fft_window.reshape((-1, 1))
y_pad = np.pad(y, int(n_fft // 2), mode='constant')
n_frames = 1 + (len(y_pad) - n_fft) // hop_length
frames = np.lib.stride_tricks.as_strided(
y_pad,
shape=(n_fft, n_frames),
strides=(y_pad.strides[0], y_pad.strides[0] * hop_length)
)
S = scipy.fft.rfft((fft_window * frames).astype(np.float32), axis=0)
S = np.abs(S)
fmin = max(fmin, 0)
fmax = min(fmax, float(sr) / 2)
fft_freqs = np.fft.rfftfreq(S.shape[0] * 2 - 2, 1.0 / sr)
if len(fft_freqs) > S.shape[0]:
fft_freqs = fft_freqs[:S.shape[0]]
shift = _parabolic_interpolation(S, axis=0)
avg = np.gradient(S, axis=0)
dskew = 0.5 * avg * shift
pitches = np.zeros_like(S)
mags = np.zeros_like(S)
freq_mask = (fmin <= fft_freqs) & (fft_freqs < fmax)
freq_mask = freq_mask.reshape(-1, 1)
ref_value = threshold * np.max(S, axis=0, keepdims=True)
local_max = _localmax(S * (S > ref_value), axis=0)
idx = np.nonzero(freq_mask & local_max)
pitches[idx] = (idx[0] + shift[idx]) * float(sr) / (S.shape[0] * 2 - 2)
mags[idx] = S[idx] + dskew[idx]
return pitches, mags
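# Illustrative sketch (not part of the original node code): reading off the
# strongest refined pitch per frame from piptrack's output. The helper name is
# hypothetical; it only shows how pitches and mags are indexed together.
def _example_dominant_pitch_per_frame(y, sr=22050):
    pitches, mags = piptrack(y=y, sr=sr, n_fft=2048, hop_length=512)
    best_bins = np.argmax(mags, axis=0)       # strongest bin in each frame
    frame_idx = np.arange(pitches.shape[1])
    return pitches[best_bins, frame_idx]      # Hz; 0.0 where nothing passed the threshold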
def hz_to_octs(frequencies, tuning=0.0, bins_per_octave=12):
"""Convert frequencies (Hz) to octave numbers."""
A440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
octs = np.log2(np.asanyarray(frequencies) / (float(A440) / 16))
return octs
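# Worked example (illustrative only): with tuning=0 the reference is A440/16 = 27.5 Hz,
# so hz_to_octs(440.0) == log2(440 / 27.5) == 4.0 and hz_to_octs(880.0) == 5.0.
def _example_hz_to_octs_sanity_check():
    assert np.isclose(hz_to_octs(440.0), 4.0)
    assert np.isclose(hz_to_octs(880.0), 5.0)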
def pitch_tuning(frequencies, resolution=0.01, bins_per_octave=12):
"""Estimate tuning offset from a collection of pitches."""
frequencies = np.atleast_1d(frequencies)
frequencies = frequencies[frequencies > 0]
if not np.any(frequencies):
return 0.0
residual = np.mod(bins_per_octave * hz_to_octs(frequencies, tuning=0.0,
bins_per_octave=bins_per_octave), 1.0)
residual[residual >= 0.5] -= 1.0
bins = np.linspace(-0.5, 0.5, int(np.ceil(1.0 / resolution)) + 1)
counts, tuning = np.histogram(residual, bins)
tuning_est = tuning[np.argmax(counts)]
return tuning_est
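# Worked example (illustrative only): a tone 30 cents sharp of A440 sits at
# 440 * 2**(0.30 / 12) Hz (about 447.7 Hz). Its fractional-bin residual is +0.30,
# so pitch_tuning([447.7]) reports a tuning offset of roughly +0.3 bins
# (to within the 0.01-bin histogram resolution used above).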
def estimate_tuning(y, sr=22050, bins_per_octave=12):
"""Estimate global tuning deviation from 12-TET."""
n_fft = 2048
hop_length = 512
if len(y) < n_fft:
return 0.0
pitch, mag = piptrack(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length,
fmin=150.0, fmax=4000.0, threshold=0.1)
pitch_mask = pitch > 0
if not pitch_mask.any():
return 0.0
threshold = np.median(mag[pitch_mask])
valid_pitches = pitch[(mag >= threshold) & pitch_mask]
if len(valid_pitches) == 0:
return 0.0
tuning = pitch_tuning(valid_pitches, resolution=0.01, bins_per_octave=bins_per_octave)
return float(tuning)
def compute_chroma_cens(y, sr=22050, hop_length=512, n_chroma=12,
n_octaves=7, bins_per_octave=36,
win_len_smooth=41, norm=2):
"""Compute Chroma Energy Normalized Statistics (CENS) features."""
tuning = estimate_tuning(y, sr, bins_per_octave=bins_per_octave)
fmin = 32.70319566257483 # C1 note frequency
n_bins = n_octaves * bins_per_octave
cqt_mag = compute_cqt(y, sr=sr, hop_length=hop_length,
fmin=fmin, n_bins=n_bins,
bins_per_octave=bins_per_octave,
tuning=tuning)
chroma_map = cq_to_chroma_mapping(n_bins, bins_per_octave=bins_per_octave,
n_chroma=n_chroma, fmin=fmin)
chroma = np.dot(chroma_map, cqt_mag)
threshold = np.finfo(chroma.dtype).tiny
chroma_sum = np.sum(np.abs(chroma), axis=0, keepdims=True)
chroma_sum = np.maximum(chroma_sum, threshold)
chroma = chroma / chroma_sum
quant_steps = [0.4, 0.2, 0.1, 0.05]
quant_weights = [0.25, 0.25, 0.25, 0.25]
chroma_quant = np.zeros_like(chroma)
for step, weight in zip(quant_steps, quant_weights):
chroma_quant += (chroma > step) * weight
if win_len_smooth is not None and win_len_smooth > 0:
win = scipy.signal.get_window('hann', win_len_smooth + 2, fftbins=False)
win /= np.sum(win)
win = win.reshape(1, -1)
chroma_smooth = scipy.ndimage.convolve(chroma_quant, win, mode='constant')
else:
chroma_smooth = chroma_quant
if norm == 2:
threshold = np.finfo(chroma_smooth.dtype).tiny
chroma_norm = np.sqrt(np.sum(chroma_smooth ** 2, axis=0, keepdims=True))
chroma_norm = np.maximum(chroma_norm, threshold)
chroma_smooth = chroma_smooth / chroma_norm
elif norm == np.inf:
threshold = np.finfo(chroma_smooth.dtype).tiny
chroma_norm = np.max(np.abs(chroma_smooth), axis=0, keepdims=True)
chroma_norm = np.maximum(chroma_norm, threshold)
chroma_smooth = chroma_smooth / chroma_norm
return chroma_smooth
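# Illustrative note (not part of the original node code): the CENS quantization
# above maps each L1-normalized chroma value to a step level by summing weights:
# > 0.4 -> 1.0, > 0.2 -> 0.75, > 0.1 -> 0.5, > 0.05 -> 0.25, otherwise 0.0.
def _example_cens_quantize(value):
    steps, weights = [0.4, 0.2, 0.1, 0.05], [0.25, 0.25, 0.25, 0.25]
    return sum(w for s, w in zip(steps, weights) if value > s)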
def _create_mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None):
"""Create mel-scale filterbank matrix."""
if fmax is None:
fmax = sr / 2.0
mel_basis = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=np.float32)
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
min_mel = hz_to_mel(fmin)
max_mel = hz_to_mel(fmax)
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mel_f = mel_to_hz(mels)
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
mel_basis[i] = np.maximum(0, np.minimum(lower, upper))
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
mel_basis *= enorm[:, np.newaxis]
return mel_basis
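# Illustrative sketch (not part of the original node code): basic shape and
# non-negativity checks for the mel filterbank; the helper name is hypothetical.
def _example_mel_filterbank_checks(sr=22050, n_fft=2048, n_mels=128):
    mel_basis = _create_mel_filterbank(sr, n_fft, n_mels=n_mels)
    assert mel_basis.shape == (n_mels, 1 + n_fft // 2)
    assert np.all(mel_basis >= 0)  # each row is one area-normalized triangular filter
    return mel_basis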
def _compute_mel_spectrogram(data, sr, n_fft=2048, hop_length=512, n_mels=128):
"""Compute mel spectrogram from audio signal."""
fft_window = scipy.signal.get_window('hann', n_fft, fftbins=True)
if len(fft_window) < n_fft:
lpad = int((n_fft - len(fft_window)) // 2)
fft_window = np.pad(fft_window, (lpad, int(n_fft - len(fft_window) - lpad)), mode='constant')
fft_window = fft_window.reshape((-1, 1))
data_padded = np.pad(data, int(n_fft // 2), mode='constant')
n_frames = 1 + (len(data_padded) - n_fft) // hop_length
shape = (n_fft, n_frames)
strides = (data_padded.strides[0], data_padded.strides[0] * hop_length)
frames = np.lib.stride_tricks.as_strided(data_padded, shape=shape, strides=strides)
stft_result = scipy.fft.rfft(fft_window * frames, axis=0).astype(np.complex64)
power_spec = np.abs(stft_result) ** 2
mel_basis = _create_mel_filterbank(sr, n_fft, n_mels=n_mels, fmin=0.0, fmax=sr / 2.0)
mel_spec = np.dot(mel_basis, power_spec)
return mel_spec.astype(np.float32)
def quick_tempo_estimate(audio_np, sr, start_bpm=120.0, std_bpm=1.0, hop_length=512):
"""Estimate tempo using autocorrelation tempogram."""
if len(audio_np) < hop_length * 10:
logging.warning("Audio too short for tempo estimation, returning default BPM of 120.0")
return 120.0
n_fft = 2048
mel_S = _compute_mel_spectrogram(audio_np, sr, n_fft=n_fft, hop_length=hop_length, n_mels=128)
log_mel_S = 10.0 * np.log10(np.maximum(1e-10, mel_S))
lag = 1
S_diff = log_mel_S[:, lag:] - log_mel_S[:, :-lag]
S_onset = np.maximum(0.0, S_diff)
onset_env_pre = np.mean(S_onset, axis=0)
pad_width = lag + n_fft // (2 * hop_length)
onset_env = np.pad(onset_env_pre, (pad_width, 0), mode='constant')
onset_env = onset_env[:mel_S.shape[1]]
return estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm, std_bpm, max_tempo=320.0)
def estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm=120.0, std_bpm=1.0, max_tempo=320.0):
"""Estimate tempo from onset strength envelope using autocorrelation tempogram."""
if len(onset_env) < 20:
return 120.0
ac_size = 8.0
win_length = int(np.round(ac_size * sr / hop_length))
win_length = min(win_length, len(onset_env))
pad_width = win_length // 2
onset_padded = np.pad(onset_env, (pad_width, pad_width), mode='linear_ramp', end_values=(0, 0))
n_frames = len(onset_env)
shape = (win_length, n_frames)
strides = (onset_padded.strides[0], onset_padded.strides[0])
frames = np.lib.stride_tricks.as_strided(onset_padded, shape=shape, strides=strides)
hann_window = scipy.signal.get_window('hann', win_length, fftbins=True)
windowed_frames = frames * hann_window[:, np.newaxis]
tempogram = np.zeros((win_length, n_frames))
for i in range(n_frames):
frame = windowed_frames[:, i]
n_pad = scipy.fft.next_fast_len(2 * len(frame) - 1)
fft_result = scipy.fft.rfft(frame, n=n_pad)
powspec = np.abs(fft_result) ** 2
ac = scipy.fft.irfft(powspec, n=n_pad)
tempogram[:, i] = ac[:win_length]
ac_max = np.max(np.abs(tempogram), axis=0)
mask = ac_max > 0
tempogram[:, mask] /= ac_max[mask]
tempogram_mean = np.mean(tempogram, axis=1)
tempogram_mean = np.maximum(tempogram_mean, 0)
bpms = np.zeros(win_length, dtype=np.float64)
bpms[0] = np.inf
bpms[1:] = 60.0 * sr / (hop_length * np.arange(1.0, win_length))
logprior = -0.5 * ((np.log2(bpms) - np.log2(start_bpm)) / std_bpm) ** 2
if max_tempo is not None:
max_idx = int(np.argmax(bpms < max_tempo))
if max_idx > 0:
logprior[:max_idx] = -np.inf
weighted = np.log1p(1e6 * tempogram_mean) + logprior
best_idx = int(np.argmax(weighted[1:])) + 1
tempo = bpms[best_idx]
return tempo
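# Illustrative note (not part of the original node code): autocorrelation lag k
# (in onset frames) corresponds to a period of k * hop_length / sr seconds, i.e.
# a tempo of 60 * sr / (hop_length * k) BPM, which is exactly how `bpms` is
# built above. Hypothetical helper:
def _example_lag_to_bpm(lag_frames, sr=22050, hop_length=512):
    return 60.0 * sr / (hop_length * lag_frames)  # e.g. lag 21 -> about 123 BPM at sr=22050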
def detect_onset_peaks(onset_env, sr=22050, hop_length=512, pre_max=0.03, post_max=0.0,
pre_avg=0.10, post_avg=0.10, wait=0.03, delta=0.07):
"""Detect onset peaks using peak picking algorithm."""
onset_normalized = onset_env - np.min(onset_env)
onset_max = np.max(onset_normalized)
if onset_max > 0:
onset_normalized = onset_normalized / onset_max
pre_max_frames = int(pre_max * sr / hop_length)
post_max_frames = int(post_max * sr / hop_length) + 1
pre_avg_frames = int(pre_avg * sr / hop_length)
post_avg_frames = int(post_avg * sr / hop_length) + 1
wait_frames = int(wait * sr / hop_length)
peaks = np.zeros(len(onset_normalized), dtype=bool)
peaks[0] = (onset_normalized[0] >= np.max(onset_normalized[:min(post_max_frames, len(onset_normalized))]))
peaks[0] &= (onset_normalized[0] >= np.mean(onset_normalized[:min(post_avg_frames, len(onset_normalized))]) + delta)
if peaks[0]:
n = wait_frames + 1
else:
n = 1
while n < len(onset_normalized):
maxn = np.max(onset_normalized[max(0, n - pre_max_frames):min(n + post_max_frames, len(onset_normalized))])
peaks[n] = (onset_normalized[n] == maxn)
if not peaks[n]:
n += 1
continue
avgn = np.mean(onset_normalized[max(0, n - pre_avg_frames):min(n + post_avg_frames, len(onset_normalized))])
peaks[n] &= (onset_normalized[n] >= avgn + delta)
if not peaks[n]:
n += 1
continue
n += wait_frames + 1
return np.flatnonzero(peaks).astype(np.int32)
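# Illustrative sketch (not part of the original node code): converting detected
# onset frames to timestamps in seconds; the helper name and defaults are
# assumptions matching the values used elsewhere in this module.
def _example_onset_times(onset_env, sr=22050, hop_length=512):
    peak_frames = detect_onset_peaks(onset_env, sr=sr, hop_length=hop_length)
    return peak_frames * hop_length / float(sr)  # seconds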
def track_beats(onset_env, tempo, sr, hop_length, tightness=100, trim=True):
"""Track beats using dynamic programming."""
frame_rate = sr / hop_length
frames_per_beat = np.round(frame_rate * 60.0 / tempo)
if frames_per_beat <= 0 or len(onset_env) < 2:
return np.array([], dtype=np.int32)
onset_std = np.std(onset_env, ddof=1)
if onset_std > 0:
onset_normalized = onset_env / onset_std
else:
onset_normalized = onset_env
window_range = np.arange(-frames_per_beat, frames_per_beat + 1)
window = np.exp(-0.5 * (window_range * 32.0 / frames_per_beat) ** 2)
localscore = scipy.signal.convolve(onset_normalized, window, mode='same')
backlink = np.full(len(localscore), -1, dtype=np.int32)
cumscore = np.zeros(len(localscore), dtype=np.float64)
score_thresh = 0.01 * localscore.max()
first_beat = True
backlink[0] = -1
cumscore[0] = localscore[0]
fpb = int(frames_per_beat)
for i in range(1, len(localscore)):
score_i = localscore[i]
best_score = -np.inf
beat_location = -1
search_start = int(i - np.round(fpb / 2.0))
search_end = int(i - 2 * fpb - 1)
for loc in range(search_start, search_end, -1):
if loc < 0:
break
score = cumscore[loc] - tightness * (np.log(i - loc) - np.log(fpb)) ** 2
if score > best_score:
best_score = score
beat_location = loc
if beat_location >= 0:
cumscore[i] = score_i + best_score
else:
cumscore[i] = score_i
if first_beat and score_i < score_thresh:
backlink[i] = -1
else:
backlink[i] = beat_location
first_beat = False
local_max_mask = np.zeros(len(cumscore), dtype=bool)
local_max_mask[0] = False
for i in range(1, len(cumscore) - 1):
local_max_mask[i] = (cumscore[i] > cumscore[i-1]) and (cumscore[i] >= cumscore[i+1])
if len(cumscore) > 1:
local_max_mask[-1] = cumscore[-1] > cumscore[-2]
if np.any(local_max_mask):
median_max = np.median(cumscore[local_max_mask])
threshold = 0.5 * median_max
tail = -1
for i in range(len(cumscore) - 1, -1, -1):
if local_max_mask[i] and cumscore[i] >= threshold:
tail = i
break
else:
tail = len(cumscore) - 1
beats = np.zeros(len(localscore), dtype=bool)
n = tail
visited = set()
while n >= 0 and n not in visited:
beats[n] = True
visited.add(n)
n = backlink[n]
if trim and np.any(beats):
beat_positions = np.flatnonzero(beats)
beat_localscores = localscore[beat_positions]
w = np.hanning(5)
smooth_boe_full = np.convolve(beat_localscores, w)
smooth_boe = smooth_boe_full[len(w)//2 : len(localscore) + len(w)//2]
threshold = 0.5 * np.sqrt(np.mean(smooth_boe ** 2))
start_frame = 0
while start_frame < len(localscore) and localscore[start_frame] <= threshold:
beats[start_frame] = False
start_frame += 1
end_frame = len(localscore) - 1
while end_frame >= 0 and localscore[end_frame] <= threshold:
beats[end_frame] = False
end_frame -= 1
return np.flatnonzero(beats).astype(np.int32)
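# Illustrative sketch (not part of the original node code): chaining the onset
# envelope, tempo estimate, and beat tracker, then converting beat frames to
# seconds. It relies on compute_onset_envelope and power_to_db, which are
# defined just below; all names and defaults here are assumptions for the example.
def _example_beat_times(y, sr=22050, n_fft=2048, hop_length=512):
    mel_db = power_to_db(_compute_mel_spectrogram(y, sr, n_fft, hop_length))
    onset_env = compute_onset_envelope(mel_db, n_fft, hop_length)
    tempo = estimate_tempo_from_onset(onset_env, sr, hop_length)
    beat_frames = track_beats(onset_env, tempo, sr, hop_length)
    return tempo, beat_frames * hop_length / float(sr)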
def compute_onset_envelope(mel_spec_db, n_fft=2048, hop_length=512):
"""Compute onset strength envelope from a log-mel spectrogram (dB)."""
lag = 1
onset_diff = mel_spec_db[:, lag:] - mel_spec_db[:, :-lag]
onset_diff = np.maximum(0.0, onset_diff)
envelope_pre_pad = np.mean(onset_diff, axis=0)
pad_width = lag + n_fft // (2 * hop_length)
envelope = np.pad(envelope_pre_pad, (pad_width, 0), mode='constant')
envelope = envelope[:mel_spec_db.shape[1]]
return envelope
def compute_mfcc(mel_spec_db, n_mfcc=20):
"""Compute MFCC features from a log-mel spectrogram (dB)."""
mfcc = scipy.fft.dct(mel_spec_db, axis=0, type=2, norm='ortho')[:n_mfcc].T
return mfcc.astype(np.float32)
def power_to_db(S, amin=1e-10, top_db=80.0, ref=1.0):
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units"""
S = np.asarray(S)
log_spec = 10.0 * np.log10(np.maximum(amin, S))
log_spec -= 10.0 * np.log10(np.maximum(amin, ref))
if top_db is not None:
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
return log_spec
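# Illustrative sketch (not part of the original node code): the log-mel / MFCC
# front end used by WanDancerEncodeAudio below, with shapes for the module
# defaults (n_mels=128, n_mfcc=20). The helper name is hypothetical.
def _example_log_mel_and_mfcc(y, sr=22050, n_fft=2048, hop_length=512):
    mel_spec = _compute_mel_spectrogram(y, sr, n_fft, hop_length, n_mels=128)  # (128, n_frames)
    mel_db = power_to_db(mel_spec)                                             # dB, clipped to the top 80 dB
    mfcc = compute_mfcc(mel_db, n_mfcc=20)                                     # (n_frames, 20)
    return mel_db, mfcc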
class WanDancerEncodeAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerEncodeAudio",
category="conditioning/video_models",
inputs=[
io.Audio.Input("audio"),
io.Int.Input("video_frames", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Float.Input("audio_inject_scale", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="The scale for the audio features when injected into the video model."),
],
outputs=[
io.AudioEncoderOutput.Output(display_name="audio_encoder_output"),
io.String.Output(display_name="fps_string", tooltip="The calculated fps based on the audio length and the number of video frames. Used in the prompt."),
],
)
@classmethod
def execute(cls, video_frames, audio_inject_scale, audio) -> io.NodeOutput:
waveform = audio["waveform"][0]
sample_rate = audio["sample_rate"]
base_fps = 30
hop_length = 512
model_sr = 22050
n_fft = 2048
# estimate the start_bpm prior from the original audio (not the resampled one) to match the reference pipeline
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=False)
start_bpm = quick_tempo_estimate(waveform.squeeze().cpu().numpy(), sample_rate, hop_length=hop_length)
# resample to the sample rate used for feature extraction
resample_sr = base_fps * hop_length
waveform = torchaudio.functional.resample(waveform, sample_rate, resample_sr)
waveform_np = waveform.cpu().numpy().squeeze()
mel_spec = _compute_mel_spectrogram(waveform_np, model_sr, n_fft, hop_length, n_mels=128)
mel_spec_db = power_to_db(mel_spec, amin=1e-10, top_db=80.0, ref=1.0)
envelope = compute_onset_envelope(mel_spec_db, n_fft, hop_length)
mfcc = compute_mfcc(mel_spec_db, n_mfcc=20)
chroma = compute_chroma_cens(y=waveform_np, sr=model_sr, hop_length=hop_length).T
# detect peaks
peak_idxs = detect_onset_peaks(envelope, sr=model_sr, hop_length=hop_length)
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
peak_onehot[peak_idxs] = 1.0
# detect beats
beat_tracking_tempo = estimate_tempo_from_onset(envelope, sr=model_sr, hop_length=hop_length, start_bpm=start_bpm)
beat_idxs = track_beats(envelope, beat_tracking_tempo, model_sr, hop_length, tightness=100, trim=True)
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
beat_onehot[beat_idxs] = 1.0
audio_feature = np.concatenate(
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
axis=-1,
)
audio_feature = torch.from_numpy(audio_feature).unsqueeze(0).to(comfy.model_management.intermediate_device())
fps = float(base_fps / int(audio_feature.shape[1] / video_frames + 0.5))
audio_encoder_output = {
"audio_feature": audio_feature,
"fps": fps,
"audio_inject_scale": audio_inject_scale,
}
if int(fps + 0.5) != 30:
fps_string = " 帧率是{:.4f}".format(fps) # "frame rate is" in Chinese, as it was in the original pipeline
else:
fps_string = ", 帧率是30fps。" # to match the reference pipeline when the fps is 30
return io.NodeOutput(audio_encoder_output, fps_string)
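# Illustrative note (not part of the original node code): with the defaults
# above, each audio frame carries 1 (onset envelope) + 20 (MFCC) + 12 (CENS
# chroma) + 1 (peak one-hot) + 1 (beat one-hot) = 35 features, so
# audio_feature has shape (1, n_frames, 35).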
class WanDancerVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="The number of frames in the generated video. Should stay 149 for WanDancer."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="The CLIP vision embeds for the first frame."),
io.ClipVisionOutput.Input("clip_vision_output_ref", optional=True, tooltip="The CLIP vision embeds for the reference image."),
io.Image.Input("start_image", optional=True, tooltip="The initial image(s) to be encoded, can be any number of frames."),
io.Mask.Input("mask", optional=True, tooltip="Image conditioning mask for the start image(s). White is kept, black is generated. Used for the local generations."),
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent."),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, start_image=None, mask=None, clip_vision_output=None, clip_vision_output_ref=None, audio_encoder_output=None) -> io.NodeOutput:
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.zeros((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
if mask is None:
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
else:
concat_mask = 1 - mask[:length].unsqueeze(0)
concat_mask = comfy.utils.common_upscale(concat_mask, concat_latent_image.shape[-2], concat_latent_image.shape[-1], "nearest-exact", "disabled")
concat_mask = torch.cat([torch.repeat_interleave(concat_mask[:, 0:1], repeats=4, dim=1), concat_mask[:, 1:]], dim=1)
concat_mask = concat_mask.view(1, concat_mask.shape[1] // 4, 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})
if audio_encoder_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
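# Illustrative note (not part of the original node code): the empty latent uses
# Wan's temporal compression of 4 and spatial compression of 8, so the default
# 149-frame, 480x832 request yields a latent of shape [1, 16, 38, 104, 60]
# (((149 - 1) // 4) + 1 = 38 temporal latents).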
class WanDancerPadKeyframes(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerPadKeyframes",
category="image/video",
inputs=[
io.Image.Input("images",),
io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of this segment (usually 149 frames)"),
io.Int.Input("segment_index", default=0, min=0, max=100, tooltip="Which segment this is (0 for first, 1 for second, etc.)"),
io.Audio.Input("audio", tooltip="Audio to calculate total output frames from and extract segment audio."),
],
outputs=[
io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequence"),
io.Mask.Output(display_name="keyframes_mask", tooltip="Mask indicating valid frames"),
io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for this video segment"),
],
)
@classmethod
def do_execute(cls, images, segment_length, segment_index, audio):
B, H, W, C = images.shape
fps = 30
# calculate total frames
audio_duration = audio["waveform"].shape[-1] / audio["sample_rate"]
segment_duration = segment_length / fps
buffer = 0.2
num_segments = int((audio_duration - buffer) / segment_duration) + 1 if audio_duration > buffer else 0
total_frames = num_segments * segment_length
mask = torch.zeros((segment_length, H, W), device=images.device, dtype=images.dtype)
keyframes = torch.zeros((segment_length, H, W, C), dtype=images.dtype, device=images.device)
# guard: with no audio or no images, nothing to place — leave keyframes/mask zeroed
if total_frames > 0 and B > 0:
frame_interval = float(total_frames) / B
seg_num = int(math.ceil(total_frames / segment_length))
is_last_segment = (segment_index == seg_num - 1)
positions = []
images_before_this_segment = 0
# count images consumed by previous segments
for seg_idx in range(segment_index):
end_idx = (total_frames - segment_length * seg_idx - 1) if seg_idx == seg_num - 1 else (segment_length - 1)
cnt = 0
while cnt * frame_interval < end_idx - frame_interval:
cnt += 1
images_before_this_segment += cnt
# positions for current segment
end_index = (total_frames - segment_length * segment_index - 1) if is_last_segment else (segment_length - 1)
cnt = 0
while cnt * frame_interval < end_index - frame_interval:
pos = int(math.ceil(frame_interval * cnt))
positions.append((pos, images_before_this_segment + cnt))
cnt += 1
positions.append((end_index, images_before_this_segment + cnt))
valid_positions = [(pos, idx) for pos, idx in positions if idx < B and pos < segment_length]
if valid_positions:
seg_positions, img_indices = zip(*valid_positions)
seg_positions = torch.tensor(seg_positions, dtype=torch.long, device=images.device)
img_indices = torch.tensor(img_indices, dtype=torch.long, device=images.device)
mask[seg_positions] = 1
keyframes[seg_positions] = images[img_indices]
# extract audio segment
segment_duration = segment_length / fps
start_time = segment_index * segment_duration
end_time = min(start_time + segment_duration, audio_duration)
sample_rate = audio["sample_rate"]
start_sample = int(start_time * sample_rate)
end_sample = int(end_time * sample_rate)
audio_segment_waveform = audio["waveform"][:, :, start_sample:end_sample]
audio_segment = {
"waveform": audio_segment_waveform,
"sample_rate": sample_rate
}
return keyframes, mask, audio_segment
@classmethod
def execute(cls, images, segment_length, segment_index, audio=None) -> io.NodeOutput:
return io.NodeOutput(*cls.do_execute(images, segment_length, segment_index, audio))
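# Worked example (illustrative only): for a 15.0 s clip at 30 fps with
# segment_length=149, segment_duration = 149 / 30 is about 4.97 s, so
# num_segments = int((15.0 - 0.2) / 4.97) + 1 = 3 and total_frames = 447.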
class WanDancerPadKeyframesList(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerPadKeyframesList",
category="image/video",
inputs=[
io.Image.Input("images"),
io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of each segment (usually 149 frames)"),
io.Int.Input("num_segments", default=1, min=1, max=100, tooltip="How many padded segments to emit as lists."),
io.Audio.Input("audio", tooltip="Audio to slice for each emitted segment."),
],
outputs=[
io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequences", is_output_list=True),
io.Mask.Output(display_name="keyframes_mask", tooltip="Masks indicating valid frames", is_output_list=True),
io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for each video segment", is_output_list=True),
],
)
@classmethod
def execute(cls, images, segment_length, num_segments, audio=None) -> io.NodeOutput:
outputs = [WanDancerPadKeyframes.do_execute(images, segment_length, i, audio) for i in range(num_segments)]
keyframes, masks, audio_segments = zip(*outputs)
return io.NodeOutput(list(keyframes), list(masks), list(audio_segments))
class WanDancerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanDancerVideo,
WanDancerEncodeAudio,
WanDancerPadKeyframes,
WanDancerPadKeyframesList,
]
async def comfy_entrypoint() -> WanDancerExtension:
return WanDancerExtension()


@@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)


@@ -2429,10 +2429,12 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
"nodes_bg_removal.py",
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",
"nodes_void.py",
"nodes_wandancer.py",
]
import_failed = []

File diff suppressed because it is too large


@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.43.17
+comfyui-frontend-package==1.43.18
comfyui-workflow-templates==0.9.72
comfyui-embedded-docs==0.4.4
torch


@@ -544,6 +544,7 @@ class PromptServer():
if os.path.isfile(file):
if 'preview' in request.rel_url.query:
with Image.open(file) as img:
img = ImageOps.exif_transpose(img)
preview_info = request.rel_url.query['preview'].split(';')
image_format = preview_info[0]
if image_format not in ['webp', 'jpeg'] or 'a' in request.rel_url.query.get('channel', ''):
@@ -569,6 +570,7 @@
if channel == 'rgb':
with Image.open(file) as img:
img = ImageOps.exif_transpose(img)
if img.mode == "RGBA":
r, g, b, a = img.split()
new_img = Image.merge('RGB', (r, g, b))
@@ -584,6 +586,7 @@
elif channel == 'a':
with Image.open(file) as img:
img = ImageOps.exif_transpose(img)
if img.mode == "RGBA":
_, _, _, a = img.split()
else:


@@ -124,9 +124,11 @@ class TestMathExpressionExecute:
with pytest.raises(Exception, match="not defined"):
self._exec("str(a)", a=42)
def test_boolean_result_raises(self):
with pytest.raises(ValueError, match="got bool"):
self._exec("a > b", a=5, b=3)
def test_boolean_result(self):
result = self._exec("a > b", a=5, b=3)
assert result[2] is True
result = self._exec("a > b", a=3, b=5)
assert result[2] is False
def test_empty_expression_raises(self):
with pytest.raises(ValueError, match="Expression cannot be empty"):