mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 11:07:24 +08:00
Compare commits
17 Commits
241d37314d
...
606bd36d90
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
606bd36d90 | ||
|
|
95f6652ef5 | ||
|
|
20f5e474da | ||
|
|
3200f28e3a | ||
|
|
a4b7e3beed | ||
|
|
7bbf1e8169 | ||
|
|
8b08bfdcbe | ||
|
|
4e823431cc | ||
|
|
66669b2ded | ||
|
|
65045730a6 | ||
|
|
87878f354f | ||
|
|
c5ecd231a2 | ||
|
|
9864f5ac86 | ||
|
|
05cd076bc1 | ||
|
|
d3c18c1636 | ||
|
|
bac6fc35fb | ||
|
|
56c74094c7 |
@ -431,9 +431,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts image brightness and contrast using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -162,7 +162,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Canny to Image (Z-Image-Turbo)",
|
||||
"name": "Canny to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1553,7 +1553,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Canny to image"
|
||||
"category": "Image generation and editing/Canny to image",
|
||||
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1574,4 +1575,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -192,7 +192,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Canny to Video (LTX 2.0)",
|
||||
"name": "Canny to Video (LTX 2.0)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -3600,7 +3600,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Canny to video"
|
||||
"category": "Video generation and editing/Canny to video",
|
||||
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -3616,4 +3617,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -377,8 +377,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds lens-style chromatic aberration (color fringing) using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -596,7 +596,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts saturation, temperature, tint, and vibrance using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -1129,7 +1129,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Balances colors across shadows, midtones, and highlights using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -608,7 +608,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Fine-tunes tone and color with per-channel curve adjustments using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -1609,7 +1609,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop"
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a 2×2 grid of four equal tiles."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -2946,7 +2946,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image Tools/Crop"
|
||||
"category": "Image Tools/Crop",
|
||||
"description": "Splits an image into a 3×3 grid of nine equal tiles."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1579,7 +1579,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image"
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
|
||||
},
|
||||
{
|
||||
"id": "458bdf3c-4b58-421c-af50-c9c663a4d74c",
|
||||
@ -2461,7 +2462,8 @@
|
||||
]
|
||||
},
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -4233,7 +4233,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Depth to video"
|
||||
"category": "Video generation and editing/Depth to video",
|
||||
"description": "Generates video from depth maps using LTX-2, with optional synchronized audio."
|
||||
},
|
||||
{
|
||||
"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2",
|
||||
@ -5192,7 +5193,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -450,9 +450,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Blur"
|
||||
"category": "Image Tools/Blur",
|
||||
"description": "Applies bilateral (edge-preserving) blur to soften images while retaining detail."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -580,8 +580,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds procedural film grain texture for a cinematic look via GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3350,7 +3350,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Video generation and editing/First-Last-Frame to Video"
|
||||
"category": "Video generation and editing/First-Last-Frame to Video",
|
||||
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -575,8 +575,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adds a glow/bloom effect around bright image areas via GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -752,8 +752,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts hue, saturation, and lightness of an image using a real-time GPU fragment shader."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -374,7 +374,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Blur"
|
||||
"category": "Image Tools/Blur",
|
||||
"description": "Applies Gaussian, Box, or Radial blur to soften images and create stylized depth or motion effects."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -310,7 +310,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Image Captioning"
|
||||
"category": "Text generation/Image Captioning",
|
||||
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -315,8 +315,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Manipulates individual RGBA channels for masking, compositing, and channel effects."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2138,7 +2138,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using FireRed Image Edit 1.1, a diffusion-based instruction-following editing model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1472,7 +1472,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits an input image via text instructions using FLUX.2 [klein] 4B."
|
||||
},
|
||||
{
|
||||
"id": "6007e698-2ebd-4917-84d8-299b35d7b7ab",
|
||||
@ -1821,7 +1822,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Applies reference image conditioning for style/identity transfer (Flux.2 Klein 4B)."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1837,4 +1839,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
|
||||
@ -1417,7 +1417,8 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using LongCat Image Edit, an instruction-following image editing diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -132,7 +132,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image Edit (Qwen 2511)",
|
||||
"name": "Image Edit (Qwen 2511)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1468,7 +1468,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Edit image"
|
||||
"category": "Image generation and editing/Edit image",
|
||||
"description": "Edits images via text instructions using Qwen-Image-Edit-2511 with improved character consistency and integrated LoRA."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1489,4 +1490,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1188,7 +1188,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Inpaint image"
|
||||
"category": "Image generation and editing/Inpaint image",
|
||||
"description": "Inpaints masked image regions using Flux.1 fill [dev], Black Forest Labs' inpainting/outpainting model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1202,4 +1203,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1548,7 +1548,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Inpaint image"
|
||||
"category": "Image generation and editing/Inpaint image",
|
||||
"description": "Inpaints masked regions using Qwen-Image, extending its multilingual text rendering to inpainting tasks."
|
||||
},
|
||||
{
|
||||
"id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8",
|
||||
@ -1907,7 +1908,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Expands and softens mask edges to reduce visible seams after image processing."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -742,9 +742,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Color adjust"
|
||||
"category": "Image Tools/Color adjust",
|
||||
"description": "Adjusts black point, white point, and gamma for tonal range control via GPU shader."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -1919,7 +1919,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Outpaint image"
|
||||
"category": "Image generation and editing/Outpaint image",
|
||||
"description": "Outpaints beyond image boundaries using Qwen-Image's outpainting capabilities."
|
||||
},
|
||||
{
|
||||
"id": "f93c215e-c393-460e-9534-ed2c3d8a652e",
|
||||
@ -2278,7 +2279,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Expands and softens mask edges to reduce visible seams after image processing."
|
||||
},
|
||||
{
|
||||
"id": "2a4b2cc0-db37-4302-a067-da392f38f06b",
|
||||
@ -2733,7 +2735,8 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Scales both image and mask together while preserving alignment for editing workflows."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -141,7 +141,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image Upscale(Z-image-Turbo)",
|
||||
"name": "Image Upscale (Z-image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1302,7 +1302,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Enhance"
|
||||
"category": "Image generation and editing/Enhance",
|
||||
"description": "Upscales images to higher resolution using Z-Image-Turbo."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -99,7 +99,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Depth Map (Lotus)",
|
||||
"name": "Image to Depth Map (Lotus)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -948,7 +948,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Depth to image"
|
||||
"category": "Image generation and editing/Depth to image",
|
||||
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -964,4 +965,4 @@
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1586,7 +1586,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Image to layers"
|
||||
"category": "Image generation and editing/Image to layers",
|
||||
"description": "Decomposes an image into variable-resolution RGBA layers for independent editing using Qwen-Image-Layered."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -72,7 +72,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Model (Hunyuan3d 2.1)",
|
||||
"name": "Image to 3D Model (Hunyuan3d 2.1)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -765,7 +765,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "3D/Image to 3D Model"
|
||||
"category": "3D/Image to 3D Model",
|
||||
"description": "Generates 3D mesh models from a single input image using Hunyuan3D 2.0/2.1."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -4223,7 +4223,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Video generation and editing/Image to video"
|
||||
"category": "Video generation and editing/Image to video",
|
||||
"description": "Generates video from a single input image using LTX-2.3."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -206,7 +206,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Image to Video (Wan 2.2)",
|
||||
"name": "Image to Video (Wan 2.2)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -2027,7 +2027,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Image to video"
|
||||
"category": "Video generation and editing/Image to video",
|
||||
"description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -134,7 +134,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Pose to Image (Z-Image-Turbo)",
|
||||
"name": "Pose to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1298,7 +1298,8 @@
|
||||
"VHS_MetadataImage": true,
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"category": "Image generation and editing/Pose to image"
|
||||
"category": "Image generation and editing/Pose to image",
|
||||
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1319,4 +1320,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -3870,7 +3870,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Pose to video"
|
||||
"category": "Video generation and editing/Pose to video",
|
||||
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -270,9 +270,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Prompt enhance"
|
||||
"category": "Text generation/Prompt enhance",
|
||||
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
@ -302,8 +302,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Sharpen"
|
||||
"category": "Image Tools/Sharpen",
|
||||
"description": "Sharpens image details using a GPU fragment shader for enhanced clarity."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -222,7 +222,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Text to Audio (ACE-Step 1.5)",
|
||||
"name": "Text to Audio (ACE-Step 1.5)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1502,7 +1502,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Audio/Music generation"
|
||||
"category": "Audio/Music generation",
|
||||
"description": "Generates audio/music from text prompts using ACE-Step 1.5, a diffusion-based audio generation model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1518,4 +1519,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -1029,7 +1029,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1043,4 +1044,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1023,7 +1023,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1037,4 +1038,4 @@
|
||||
},
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1104,7 +1104,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using NetaYume Lumina, fine-tuned from Neta Lumina for anime-style and illustration generation."
|
||||
},
|
||||
{
|
||||
"id": "a07fdf06-1bda-4dac-bdbd-63ee8ebca1c9",
|
||||
@ -1458,11 +1459,12 @@
|
||||
],
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
}
|
||||
},
|
||||
"description": "Encodes a negative text prompt via CLIP for classifier-free guidance in anime-style generation (NetaYume Lumina)."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"ue_links": []
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1941,7 +1941,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Qwen-Image-2512, with enhanced human realism and finer natural detail over the base version."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1873,7 +1873,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Qwen-Image, Alibaba's 20B MMDiT model with excellent multilingual text rendering."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -149,7 +149,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Text to Image (Z-Image-Turbo)",
|
||||
"name": "Text to Image (Z-Image-Turbo)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -1054,7 +1054,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1075,4 +1076,4 @@
|
||||
}
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -4286,7 +4286,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "Vue-corrected"
|
||||
},
|
||||
"category": "Video generation and editing/Text to video"
|
||||
"category": "Video generation and editing/Text to video",
|
||||
"description": "Generates video from text prompts using LTX-2.3, Lightricks' video diffusion model."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -1572,7 +1572,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Text to video"
|
||||
"category": "Video generation and editing/Text to video",
|
||||
"description": "Generates video from text prompts using Wan2.2, Alibaba's diffusion video model."
|
||||
}
|
||||
]
|
||||
},
|
||||
@ -1586,4 +1587,4 @@
|
||||
"VHS_KeepIntermediate": true
|
||||
},
|
||||
"version": 0.4
|
||||
}
|
||||
}
|
||||
@ -434,8 +434,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Image Tools/Sharpen"
|
||||
"category": "Image Tools/Sharpen",
|
||||
"description": "Enhances edge contrast via unsharp masking for a sharper image appearance."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -307,7 +307,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Text generation/Video Captioning"
|
||||
"category": "Text generation/Video Captioning",
|
||||
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@ -165,7 +165,7 @@
|
||||
},
|
||||
"revision": 0,
|
||||
"config": {},
|
||||
"name": "local-Video Inpaint(Wan2.1 VACE)",
|
||||
"name": "Video Inpaint (Wan 2.1 VACE)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
@ -2368,7 +2368,8 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Inpaint video"
|
||||
"category": "Video generation and editing/Inpaint video",
|
||||
"description": "Inpaints masked regions in video frames using Wan 2.1 VACE."
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
@ -584,8 +584,9 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video Tools/Stitch videos"
|
||||
"category": "Video Tools/Stitch videos",
|
||||
"description": "Stitches multiple video clips into a single sequential video file."
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -412,9 +412,10 @@
|
||||
"extra": {
|
||||
"workflowRendererVersion": "LG"
|
||||
},
|
||||
"category": "Video generation and editing/Enhance video"
|
||||
"category": "Video generation and editing/Enhance video",
|
||||
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
}
|
||||
}
|
||||
7
comfy/background_removal/birefnet.json
Normal file
7
comfy/background_removal/birefnet.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"model_type": "birefnet",
|
||||
"image_std": [1.0, 1.0, 1.0],
|
||||
"image_mean": [0.0, 0.0, 0.0],
|
||||
"image_size": 1024,
|
||||
"resize_to_original": true
|
||||
}
|
||||
689
comfy/background_removal/birefnet.py
Normal file
689
comfy/background_removal/birefnet.py
Normal file
@ -0,0 +1,689 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
import torch.nn.functional as F
|
||||
from torchvision.ops import deform_conv2d
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
|
||||
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
k, v = kv[0], kv[1]
|
||||
|
||||
x = optimized_attention(
|
||||
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
|
||||
).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
|
||||
return x
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
|
||||
self.act = nn.GELU()
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.fc2(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
B, H, W, C = x.shape
|
||||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
|
||||
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
|
||||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
|
||||
|
||||
coords_h = torch.arange(self.window_size[0])
|
||||
coords_w = torch.arange(self.window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||
self.register_buffer("relative_position_index", relative_position_index)
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
||||
mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
|
||||
self.attn = WindowAttention(
|
||||
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.H = None
|
||||
self.W = None
|
||||
|
||||
def forward(self, x, mask_matrix):
|
||||
B, L, C = x.shape
|
||||
H, W = self.H, self.W
|
||||
|
||||
shortcut = x
|
||||
x = self.norm1(x)
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
pad_l = pad_t = 0
|
||||
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||
_, Hp, Wp, _ = x.shape
|
||||
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||
attn_mask = mask_matrix
|
||||
else:
|
||||
shifted_x = x
|
||||
attn_mask = None
|
||||
|
||||
x_windows = window_partition(shifted_x, self.window_size)
|
||||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
||||
|
||||
attn_windows = self.attn(x_windows, mask=attn_mask)
|
||||
|
||||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
||||
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||
else:
|
||||
x = shifted_x
|
||||
|
||||
if pad_r > 0 or pad_b > 0:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
|
||||
x = x.view(B, H * W, C)
|
||||
|
||||
x = shortcut + x
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
def __init__(self, dim, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, L, C = x.shape
|
||||
x = x.view(B, H, W, C)
|
||||
|
||||
# padding
|
||||
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||
if pad_input:
|
||||
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||
|
||||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||
|
||||
x = self.norm(x)
|
||||
x = self.reduction(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.shift_size = window_size // 2
|
||||
self.depth = depth
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
norm_layer=norm_layer,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
for i in range(depth)])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x, H, W):
|
||||
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size, -self.shift_size),
|
||||
slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(img_mask, self.window_size)
|
||||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||
|
||||
for blk in self.blocks:
|
||||
blk.H, blk.W = H, W
|
||||
x = blk(x, attn_mask)
|
||||
if self.downsample is not None:
|
||||
x_down = self.downsample(x, H, W)
|
||||
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||
return x, H, W, x_down, Wh, Ww
|
||||
else:
|
||||
return x, H, W, x, H, W
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
patch_size = (patch_size, patch_size)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.size()
|
||||
if W % self.patch_size[1] != 0:
|
||||
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||
if H % self.patch_size[0] != 0:
|
||||
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||
|
||||
x = self.proj(x) # B C Wh Ww
|
||||
if self.norm is not None:
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformer(nn.Module):
|
||||
def __init__(self,
|
||||
pretrain_img_size=224,
|
||||
patch_size=4,
|
||||
in_channels=3,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=7,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
qk_scale=None,
|
||||
patch_norm=True,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
frozen_stages=-1,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
|
||||
self.pretrain_img_size = pretrain_img_size
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_norm = patch_norm
|
||||
self.out_indices = out_indices
|
||||
self.frozen_stages = frozen_stages
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=int(embed_dim * 2 ** i_layer),
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.layers.append(layer)
|
||||
|
||||
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||
self.num_features = num_features
|
||||
|
||||
for i_layer in out_indices:
|
||||
layer = norm_layer(num_features[i_layer])
|
||||
layer_name = f'norm{i_layer}'
|
||||
self.add_module(layer_name, layer)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.patch_embed(x)
|
||||
|
||||
Wh, Ww = x.size(2), x.size(3)
|
||||
|
||||
outs = []
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
for i in range(self.num_layers):
|
||||
layer = self.layers[i]
|
||||
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||
|
||||
if i in self.out_indices:
|
||||
norm_layer = getattr(self, f'norm{i}')
|
||||
x_out = norm_layer(x_out)
|
||||
|
||||
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
||||
|
||||
class DeformableConv2d(nn.Module):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False, device=None, dtype=None, operations=None):
|
||||
|
||||
super(DeformableConv2d, self).__init__()
|
||||
|
||||
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
|
||||
self.stride = stride if type(stride) is tuple else (stride, stride)
|
||||
self.padding = padding
|
||||
|
||||
self.offset_conv = operations.Conv2d(in_channels,
|
||||
2 * kernel_size[0] * kernel_size[1],
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.modulator_conv = operations.Conv2d(in_channels,
|
||||
1 * kernel_size[0] * kernel_size[1],
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.regular_conv = operations.Conv2d(in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=self.padding,
|
||||
bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
offset = self.offset_conv(x)
|
||||
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
||||
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
|
||||
|
||||
x = deform_conv2d(
|
||||
input=x,
|
||||
offset=offset,
|
||||
weight=weight,
|
||||
bias=None,
|
||||
padding=self.padding,
|
||||
mask=modulator,
|
||||
stride=self.stride,
|
||||
)
|
||||
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
|
||||
return x
|
||||
|
||||
class BasicDecBlk(nn.Module):
|
||||
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
|
||||
super(BasicDecBlk, self).__init__()
|
||||
inter_channels = 64
|
||||
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||
self.relu_in = nn.ReLU(inplace=True)
|
||||
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
|
||||
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
|
||||
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
x = self.bn_in(x)
|
||||
x = self.relu_in(x)
|
||||
x = self.dec_att(x)
|
||||
x = self.conv_out(x)
|
||||
x = self.bn_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class BasicLatBlk(nn.Module):
|
||||
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
|
||||
super(BasicLatBlk, self).__init__()
|
||||
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class _ASPPModuleDeformable(nn.Module):
|
||||
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
|
||||
super(_ASPPModuleDeformable, self).__init__()
|
||||
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
|
||||
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
|
||||
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.atrous_conv(x)
|
||||
x = self.bn(x)
|
||||
|
||||
return self.relu(x)
|
||||
|
||||
|
||||
class ASPPDeformable(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
|
||||
super(ASPPDeformable, self).__init__()
|
||||
self.down_scale = 1
|
||||
if out_channels is None:
|
||||
out_channels = in_channels
|
||||
self.in_channelster = 256 // self.down_scale
|
||||
|
||||
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
|
||||
self.aspp_deforms = nn.ModuleList([
|
||||
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
|
||||
for conv_size in parallel_block_sizes
|
||||
])
|
||||
|
||||
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
||||
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
|
||||
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=True))
|
||||
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
|
||||
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.aspp1(x)
|
||||
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
|
||||
x5 = self.global_avg_pool(x)
|
||||
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
|
||||
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
return x
|
||||
|
||||
class BiRefNet(nn.Module):
|
||||
def __init__(self, config=None, dtype=None, device=None, operations=None):
|
||||
super(BiRefNet, self).__init__()
|
||||
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
channels = [1536, 768, 384, 192]
|
||||
channels = [c * 2 for c in channels]
|
||||
self.cxt = channels[1:][::-1][-3:]
|
||||
self.squeeze_module = nn.Sequential(*[
|
||||
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(1)
|
||||
])
|
||||
|
||||
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_enc(self, x):
|
||||
x1, x2, x3, x4 = self.bb(x)
|
||||
B, C, H, W = x.shape
|
||||
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
|
||||
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||
x4 = torch.cat(
|
||||
(
|
||||
*[
|
||||
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||
][-len(CXT):],
|
||||
x4
|
||||
),
|
||||
dim=1
|
||||
)
|
||||
return (x1, x2, x3, x4)
|
||||
|
||||
def forward_ori(self, x):
|
||||
(x1, x2, x3, x4) = self.forward_enc(x)
|
||||
x4 = self.squeeze_module(x4)
|
||||
features = [x, x1, x2, x3, x4]
|
||||
scaled_preds = self.decoder(features)
|
||||
return scaled_preds
|
||||
|
||||
def forward(self, pixel_values, intermediate_output=None):
|
||||
scaled_preds = self.forward_ori(pixel_values)
|
||||
return scaled_preds
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, channels, device, dtype, operations):
|
||||
super(Decoder, self).__init__()
|
||||
# factory kwargs
|
||||
fk = {"device":device, "dtype":dtype, "operations":operations}
|
||||
DecoderBlock = partial(BasicDecBlk, **fk)
|
||||
LateralBlock = partial(BasicLatBlk, **fk)
|
||||
DBlock = partial(SimpleConvs, **fk)
|
||||
|
||||
self.split = True
|
||||
N_dec_ipt = 64
|
||||
ic = 64
|
||||
ipt_cha_opt = 1
|
||||
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
|
||||
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
|
||||
|
||||
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
|
||||
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
|
||||
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
|
||||
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
|
||||
|
||||
fk = {"device":device, "dtype":dtype}
|
||||
|
||||
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
|
||||
|
||||
self.lateral_block4 = LateralBlock(channels[1], channels[1])
|
||||
self.lateral_block3 = LateralBlock(channels[2], channels[2])
|
||||
self.lateral_block2 = LateralBlock(channels[3], channels[3])
|
||||
|
||||
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
|
||||
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
|
||||
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
|
||||
|
||||
_N = 16
|
||||
|
||||
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||
|
||||
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||
|
||||
def get_patches_batch(self, x, p):
|
||||
_size_h, _size_w = p.shape[2:]
|
||||
patches_batch = []
|
||||
for idx in range(x.shape[0]):
|
||||
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
|
||||
patches_x = []
|
||||
for column_x in columns_x:
|
||||
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
|
||||
patch_sample = torch.cat(patches_x, dim=1)
|
||||
patches_batch.append(patch_sample)
|
||||
return torch.cat(patches_batch, dim=0)
|
||||
|
||||
def forward(self, features):
|
||||
x, x1, x2, x3, x4 = features
|
||||
|
||||
patches_batch = self.get_patches_batch(x, x4) if self.split else x
|
||||
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p4 = self.decoder_block4(x4)
|
||||
p4_gdt = self.gdt_convs_4(p4)
|
||||
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
|
||||
p4 = p4 * gdt_attn_4
|
||||
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p3 = _p4 + self.lateral_block4(x3)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
|
||||
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p3 = self.decoder_block3(_p3)
|
||||
|
||||
p3_gdt = self.gdt_convs_3(p3)
|
||||
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
||||
p3 = p3 * gdt_attn_3
|
||||
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p2 = _p3 + self.lateral_block3(x2)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
|
||||
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p2 = self.decoder_block2(_p2)
|
||||
|
||||
p2_gdt = self.gdt_convs_2(p2)
|
||||
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
||||
p2 = p2 * gdt_attn_2
|
||||
|
||||
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
||||
_p1 = _p2 + self.lateral_block2(x1)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
_p1 = self.decoder_block1(_p1)
|
||||
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
|
||||
|
||||
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||
p1_out = self.conv_out1(_p1)
|
||||
return p1_out
|
||||
|
||||
|
||||
class SimpleConvs(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv_out(self.conv1(x))
|
||||
78
comfy/bg_removal_model.py
Normal file
78
comfy/bg_removal_model.py
Normal file
@ -0,0 +1,78 @@
|
||||
from .utils import load_torch_file
|
||||
import os
|
||||
import json
|
||||
import torch
|
||||
import logging
|
||||
|
||||
import comfy.ops
|
||||
import comfy.model_patcher
|
||||
import comfy.model_management
|
||||
import comfy.clip_model
|
||||
import comfy.background_removal.birefnet
|
||||
|
||||
BG_REMOVAL_MODELS = {
|
||||
"birefnet": comfy.background_removal.birefnet.BiRefNet
|
||||
}
|
||||
|
||||
class BackgroundRemovalModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 1024)
|
||||
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
|
||||
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
|
||||
self.model_type = config.get("model_type", "birefnet")
|
||||
self.config = config.copy()
|
||||
model_class = BG_REMOVAL_MODELS.get(self.model_type)
|
||||
|
||||
self.load_device = comfy.model_management.text_encoder_device()
|
||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||
self.model.eval()
|
||||
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def encode_image(self, image):
|
||||
comfy.model_management.load_model_gpu(self.patcher)
|
||||
H, W = image.shape[1], image.shape[2]
|
||||
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
|
||||
out = self.model(pixel_values=pixel_values)
|
||||
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||
|
||||
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != 1:
|
||||
mask = mask.movedim(-1, 1)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def load_background_removal_model(sd):
|
||||
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
bg_model = BackgroundRemovalModel(json_config)
|
||||
m, u = bg_model.load_sd(sd)
|
||||
if len(m) > 0:
|
||||
logging.warning("missing background removal: {}".format(m))
|
||||
u = set(u)
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
if k not in u:
|
||||
sd.pop(k)
|
||||
return bg_model
|
||||
|
||||
def load(ckpt_path):
|
||||
sd = load_torch_file(ckpt_path)
|
||||
return load_background_removal_model(sd)
|
||||
@ -93,7 +93,7 @@ class Hook:
|
||||
self.hook_scope = hook_scope
|
||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||
self.custom_should_register = default_should_register
|
||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||
|
||||
@property
|
||||
def strength(self):
|
||||
|
||||
@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
# according to the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
|
||||
@ -1135,7 +1135,7 @@ class AudioInjector_WAN(nn.Module):
|
||||
self.injector_adain_output_layers = nn.ModuleList(
|
||||
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
|
||||
|
||||
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
|
||||
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len, scale=1.0):
|
||||
audio_attn_id = self.injected_block_id.get(block_id, None)
|
||||
if audio_attn_id is None:
|
||||
return x
|
||||
@ -1148,12 +1148,15 @@ class AudioInjector_WAN(nn.Module):
|
||||
attn_hidden_states = adain_hidden_states
|
||||
else:
|
||||
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
|
||||
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||
attn_audio_emb = audio_emb
|
||||
|
||||
if audio_emb.dim() == 3: # WanDancer case
|
||||
attn_audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
|
||||
else: # S2V case
|
||||
attn_audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
|
||||
|
||||
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
|
||||
residual_out = rearrange(
|
||||
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||
x[:, :seq_len] = x[:, :seq_len] + residual_out
|
||||
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
|
||||
x[:, :seq_len] = x[:, :seq_len] + residual_out * scale
|
||||
return x
|
||||
|
||||
|
||||
|
||||
251
comfy/ldm/wan/model_wandancer.py
Normal file
251
comfy/ldm/wan/model_wandancer.py
Normal file
@ -0,0 +1,251 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import comfy
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
from .model import AudioInjector_WAN, WanModel, MLPProj, Head, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
class MusicSelfAttention(nn.Module):
|
||||
def __init__(self, dim, num_heads, device=None, dtype=None, operations=None):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.embed_dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.q_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.k_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.v_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.out_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, freqs):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
q = self.q_proj(x).view(b, s, n, d)
|
||||
q = apply_rope1(q, freqs)
|
||||
|
||||
k = self.k_proj(x).view(b, s, n, d)
|
||||
k = apply_rope1(k, freqs)
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
self.v_proj(x).view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
)
|
||||
|
||||
return self.out_proj(x)
|
||||
|
||||
|
||||
class MusicEncoderLayer(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int, ffn_dim: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.self_attn = MusicSelfAttention(dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.linear1 = operations.Linear(dim, ffn_dim, device=device, dtype=dtype)
|
||||
self.linear2 = operations.Linear(ffn_dim, dim, device=device, dtype=dtype)
|
||||
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.self_attn(self.norm1(x), freqs=freqs)
|
||||
x = x + self.linear2(torch.nn.functional.gelu(self.linear1(self.norm2(x)))) # ffn
|
||||
return x
|
||||
|
||||
|
||||
class WanDancerModel(WanModel):
|
||||
def __init__(self,
|
||||
model_type='wandancer',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=5120,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=40,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
in_dim_ref_conv=None,
|
||||
image_model=None,
|
||||
device=None, dtype=None, operations=None,
|
||||
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
|
||||
music_dim = 256,
|
||||
music_heads = 4,
|
||||
music_feature_dim = 35,
|
||||
music_latent_dim = 256
|
||||
):
|
||||
|
||||
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim,
|
||||
num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, in_dim_ref_conv=in_dim_ref_conv,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.patch_embedding_global = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
|
||||
self.img_emb_refimage = MLPProj(1280, dim, operation_settings=operation_settings)
|
||||
self.head_global = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
|
||||
|
||||
self.music_injector = AudioInjector_WAN(
|
||||
dim=self.dim,
|
||||
num_heads=self.num_heads,
|
||||
inject_layer=audio_inject_layers,
|
||||
root_net=self,
|
||||
enable_adain=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.music_projection = operations.Linear(music_feature_dim, music_latent_dim, device=device, dtype=dtype)
|
||||
self.music_encoder = nn.ModuleList([MusicEncoderLayer(dim=music_dim, num_heads=music_heads, ffn_dim=1024, device=device, dtype=dtype, operations=operations) for _ in range(2)])
|
||||
music_head_dim = music_dim // music_heads
|
||||
self.music_rope_embedder = EmbedND(dim=music_head_dim, theta=10000.0, axes_dim=[music_head_dim])
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, clip_fea_ref=None, freqs=None, audio_embed=None, fps=30, audio_inject_scale=1.0, transformer_options={}, **kwargs):
|
||||
# embeddings
|
||||
if int(fps + 0.5) != 30:
|
||||
x = self.patch_embedding_global(x.float()).to(x.dtype)
|
||||
else:
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
|
||||
grid_sizes = x.shape[2:]
|
||||
latent_frames = grid_sizes[0]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
seq_len = x.size(1)
|
||||
|
||||
# time embeddings
|
||||
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
|
||||
e = e.reshape(t.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
full_ref = None
|
||||
if self.ref_conv is not None: # model has the weight, but this wasn't used in the original pipeline
|
||||
full_ref = kwargs.get("reference_latent", None)
|
||||
if full_ref is not None:
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
audio_emb = None
|
||||
if audio_embed is not None: # encode music feature,[1, frame_num, 35] -> [1, F*8, dim]
|
||||
music_feature = self.music_projection(audio_embed)
|
||||
|
||||
music_seq_len = music_feature.shape[1]
|
||||
music_ids = torch.arange(music_seq_len, device=music_feature.device, dtype=music_feature.dtype).reshape(1, -1, 1) # create 1D position IDs
|
||||
music_freqs = self.music_rope_embedder(music_ids).movedim(1, 2)
|
||||
|
||||
# apply encoder layers
|
||||
for layer in self.music_encoder:
|
||||
music_feature = layer(music_feature, music_freqs)
|
||||
|
||||
# interpolate
|
||||
audio_emb = torch.nn.functional.interpolate(music_feature.unsqueeze(1), size=(latent_frames * 8, self.dim), mode='bilinear').squeeze(1)
|
||||
|
||||
context_img_len = 0
|
||||
if self.img_emb is not None and clip_fea is not None:
|
||||
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
|
||||
context = torch.cat([context_clip, context], dim=1)
|
||||
context_img_len += clip_fea.shape[-2]
|
||||
if self.img_emb_refimage is not None and clip_fea_ref is not None:
|
||||
context_clip_ref = self.img_emb_refimage(clip_fea_ref)
|
||||
context = torch.cat([context_clip_ref, context], dim=1)
|
||||
context_img_len += clip_fea_ref.shape[-2]
|
||||
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if audio_emb is not None:
|
||||
x = self.music_injector(x, i, audio_emb, audio_emb_global=None, seq_len=seq_len, scale=audio_inject_scale)
|
||||
|
||||
# head
|
||||
if int(fps + 0.5) != 30:
|
||||
x = self.head_global(x, e)
|
||||
else:
|
||||
x = self.head(x, e)
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
# unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, clip_fea_ref=None, fps=30, audio_inject_scale=1.0, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
|
||||
x = torch.cat([x, time_dim_concat], dim=2)
|
||||
t_len = x.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, fps=fps, transformer_options=transformer_options)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, clip_fea_ref=clip_fea_ref, freqs=freqs, fps=fps, audio_inject_scale=audio_inject_scale, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, fps=30, device=None, dtype=None, transformer_options={}):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
|
||||
if int(fps + 0.5) != 30:
|
||||
time_scale = 30.0 / fps # how many time units each frame represents relative to 30fps
|
||||
positions_new = torch.arange(steps_t, device=device, dtype=dtype) * time_scale + t_start
|
||||
total_frames_at_30fps = int(time_scale * steps_t + 0.5)
|
||||
positions_new[-1] = t_start + (total_frames_at_30fps - 1)
|
||||
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + positions_new.reshape(-1, 1, 1)
|
||||
else:
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
@ -43,6 +43,7 @@ import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
@ -1599,6 +1600,30 @@ class WAN21_SCAIL(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
class WAN22_WanDancer(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
audio_embed = kwargs.get("audio_embed", None)
|
||||
if audio_embed is not None:
|
||||
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
|
||||
|
||||
clip_vision_output_ref = kwargs.get("clip_vision_output_ref", None)
|
||||
if clip_vision_output_ref is not None:
|
||||
out['clip_fea_ref'] = comfy.conds.CONDRegular(clip_vision_output_ref.penultimate_hidden_states)
|
||||
|
||||
fps = kwargs.get("fps", None)
|
||||
if fps is not None:
|
||||
out['fps'] = comfy.conds.CONDRegular(torch.FloatTensor([fps]))
|
||||
|
||||
audio_inject_scale = kwargs.get("audio_inject_scale", None)
|
||||
if audio_inject_scale is not None:
|
||||
out['audio_inject_scale'] = comfy.conds.CONDRegular(torch.FloatTensor([audio_inject_scale]))
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
|
||||
|
||||
@ -572,6 +572,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "animate"
|
||||
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail"
|
||||
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "wandancer"
|
||||
else:
|
||||
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "i2v"
|
||||
|
||||
22
comfy/ops.py
22
comfy/ops.py
@ -562,6 +562,25 @@ class disable_weight_init:
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
|
||||
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
|
||||
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
run_every_op()
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||
else:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
|
||||
class Conv3d(disable_weight_init.Conv3d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class BatchNorm2d(disable_weight_init.BatchNorm2d):
|
||||
comfy_cast_weights = True
|
||||
|
||||
class GroupNorm(disable_weight_init.GroupNorm):
|
||||
comfy_cast_weights = True
|
||||
|
||||
|
||||
@ -1313,6 +1313,37 @@ class WAN21_SCAIL(WAN21_T2V):
|
||||
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_WanDancer(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "wandancer",
|
||||
"in_dim": 36,
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = 1.8
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN22_WanDancer(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
# split music_encoder in_proj into q_proj, k_proj, v_proj
|
||||
if "music_encoder" in k and "self_attn.in_proj" in k:
|
||||
suffix = "weight" if k.endswith("weight") else "bias"
|
||||
tensor = state_dict[k]
|
||||
d = tensor.shape[0] // 3
|
||||
prefix = k.replace(f"in_proj_{suffix}", "")
|
||||
out_sd[f"{prefix}q_proj.{suffix}"] = tensor[:d]
|
||||
out_sd[f"{prefix}k_proj.{suffix}"] = tensor[d:2*d]
|
||||
out_sd[f"{prefix}v_proj.{suffix}"] = tensor[2*d:]
|
||||
else:
|
||||
out_sd[k] = state_dict[k]
|
||||
return out_sd
|
||||
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@ -1982,6 +2013,7 @@ models = [
|
||||
WAN22_Animate,
|
||||
WAN21_FlowRVS,
|
||||
WAN21_SCAIL,
|
||||
WAN22_WanDancer,
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
|
||||
@ -1390,7 +1390,7 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
layer_conf = {"format": "float8_e4m3fn"}
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
||||
from spandrel import ImageModelDescriptor
|
||||
from comfy.clip_vision import ClipVisionModel
|
||||
from comfy.clip_vision import Output as ClipVisionOutput_
|
||||
from comfy.bg_removal_model import BackgroundRemovalModel
|
||||
from comfy.controlnet import ControlNet
|
||||
from comfy.hooks import HookGroup, HookKeyframeGroup
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
@ -614,6 +615,11 @@ class Model(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = ModelPatcher
|
||||
|
||||
@comfytype(io_type="BACKGROUND_REMOVAL")
|
||||
class BackgroundRemoval(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = BackgroundRemovalModel
|
||||
|
||||
@comfytype(io_type="CLIP_VISION")
|
||||
class ClipVision(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@ -2257,6 +2263,7 @@ __all__ = [
|
||||
"ModelPatch",
|
||||
"ClipVision",
|
||||
"ClipVisionOutput",
|
||||
"BackgroundRemoval",
|
||||
"AudioEncoder",
|
||||
"AudioEncoderOutput",
|
||||
"StyleModel",
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
from enum import Enum
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
|
||||
class TripoModelVersion(str, Enum):
|
||||
v3_1_20260211 = 'v3.1-20260211'
|
||||
v3_0_20250812 = 'v3.0-20250812'
|
||||
v2_5_20250123 = 'v2.5-20250123'
|
||||
v2_0_20240919 = 'v2.0-20240919'
|
||||
@ -142,7 +143,7 @@ class TripoFileEmptyReference(BaseModel):
|
||||
pass
|
||||
|
||||
class TripoFileReference(RootModel):
|
||||
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
|
||||
root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference
|
||||
|
||||
class TripoGetStsTokenRequest(BaseModel):
|
||||
format: str = Field(..., description='The format of the image')
|
||||
@ -183,7 +184,7 @@ class TripoImageToModelRequest(BaseModel):
|
||||
|
||||
class TripoMultiviewToModelRequest(BaseModel):
|
||||
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
|
||||
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
|
||||
files: list[TripoFileReference] = Field(..., description='The file references to convert to a model')
|
||||
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
|
||||
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
@ -251,27 +252,13 @@ class TripoConvertModelRequest(BaseModel):
|
||||
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
|
||||
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
|
||||
bake: Optional[bool] = Field(None, description='Whether to bake the model')
|
||||
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
|
||||
part_names: Optional[list[str]] = Field(None, description='The names of the parts to include')
|
||||
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
|
||||
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
|
||||
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
|
||||
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
|
||||
|
||||
|
||||
class TripoTaskRequest(RootModel):
|
||||
root: Union[
|
||||
TripoTextToModelRequest,
|
||||
TripoImageToModelRequest,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoTextureModelRequest,
|
||||
TripoRefineModelRequest,
|
||||
TripoAnimatePrerigcheckRequest,
|
||||
TripoAnimateRigRequest,
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoStylizeModelRequest,
|
||||
TripoConvertModelRequest
|
||||
]
|
||||
|
||||
class TripoTaskOutput(BaseModel):
|
||||
model: Optional[str] = Field(None, description='URL to the model')
|
||||
base_model: Optional[str] = Field(None, description='URL to the base model')
|
||||
@ -283,12 +270,13 @@ class TripoTask(BaseModel):
|
||||
task_id: str = Field(..., description='The task ID')
|
||||
type: Optional[str] = Field(None, description='The type of task')
|
||||
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
|
||||
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
|
||||
input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task')
|
||||
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
|
||||
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
|
||||
create_time: Optional[int] = Field(None, description='The creation time of the task')
|
||||
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
|
||||
queue_position: Optional[int] = Field(None, description='The position in the queue')
|
||||
consumed_credit: int | None = Field(None)
|
||||
|
||||
class TripoTaskResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
@ -296,7 +284,7 @@ class TripoTaskResponse(BaseModel):
|
||||
|
||||
class TripoGeneralResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
data: Dict[str, str] = Field(..., description='The task ID data')
|
||||
data: dict[str, str] = Field(..., description='The task ID data')
|
||||
|
||||
class TripoBalanceData(BaseModel):
|
||||
balance: float = Field(..., description='The account balance')
|
||||
|
||||
@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
def _seedance2_text_inputs(resolutions: list[str]):
|
||||
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||
return [
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||
default=default_ratio,
|
||||
tooltip="Aspect ratio of the output video.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def _seedance2_reference_inputs(resolutions: list[str]):
|
||||
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||
return [
|
||||
*_seedance2_text_inputs(resolutions),
|
||||
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
|
||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0",
|
||||
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Seedance 2.0 Fast",
|
||||
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||
),
|
||||
],
|
||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||
),
|
||||
|
||||
@ -60,6 +60,7 @@ async def poll_until_finished(
|
||||
],
|
||||
status_extractor=lambda x: x.data.status,
|
||||
progress_extractor=lambda x: x.data.progress,
|
||||
price_extractor=lambda x: x.data.consumed_credit * 0.01 if x.data.consumed_credit else None,
|
||||
estimated_duration=average_duration,
|
||||
)
|
||||
if response_poll.data.status == TripoTaskStatus.SUCCESS:
|
||||
@ -113,7 +114,6 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"model_version",
|
||||
"style",
|
||||
"texture",
|
||||
"pbr",
|
||||
"quad",
|
||||
@ -124,20 +124,17 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
expr="""
|
||||
(
|
||||
$isV14 := $contains(widgets.model_version,"v1.4");
|
||||
$style := widgets.style;
|
||||
$hasStyle := ($style != "" and $style != "none");
|
||||
$isV3OrLater := $contains(widgets.model_version,"v3.");
|
||||
$withTexture := widgets.texture or widgets.pbr;
|
||||
$isHdTexture := (widgets.texture_quality = "detailed");
|
||||
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
|
||||
$baseCredits :=
|
||||
$isV14 ? 20 : ($withTexture ? 20 : 10);
|
||||
$credits :=
|
||||
$baseCredits
|
||||
+ ($hasStyle ? 5 : 0)
|
||||
$credits := $isV14 ? 20 : (
|
||||
($withTexture ? 20 : 10)
|
||||
+ (widgets.quad ? 5 : 0)
|
||||
+ ($isHdTexture ? 10 : 0)
|
||||
+ ($isDetailedGeometry ? 20 : 0);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2)}
|
||||
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
|
||||
);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -239,7 +236,6 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"model_version",
|
||||
"style",
|
||||
"texture",
|
||||
"pbr",
|
||||
"quad",
|
||||
@ -250,20 +246,17 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
expr="""
|
||||
(
|
||||
$isV14 := $contains(widgets.model_version,"v1.4");
|
||||
$style := widgets.style;
|
||||
$hasStyle := ($style != "" and $style != "none");
|
||||
$isV3OrLater := $contains(widgets.model_version,"v3.");
|
||||
$withTexture := widgets.texture or widgets.pbr;
|
||||
$isHdTexture := (widgets.texture_quality = "detailed");
|
||||
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
|
||||
$baseCredits :=
|
||||
$isV14 ? 30 : ($withTexture ? 30 : 20);
|
||||
$credits :=
|
||||
$baseCredits
|
||||
+ ($hasStyle ? 5 : 0)
|
||||
$credits := $isV14 ? 30 : (
|
||||
($withTexture ? 30 : 20)
|
||||
+ (widgets.quad ? 5 : 0)
|
||||
+ ($isHdTexture ? 10 : 0)
|
||||
+ ($isDetailedGeometry ? 20 : 0);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2)}
|
||||
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
|
||||
);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -358,7 +351,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
|
||||
),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."),
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
],
|
||||
outputs=[
|
||||
@ -379,7 +372,6 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
"model_version",
|
||||
"texture",
|
||||
"pbr",
|
||||
"quad",
|
||||
"texture_quality",
|
||||
"geometry_quality",
|
||||
],
|
||||
@ -387,17 +379,16 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
expr="""
|
||||
(
|
||||
$isV14 := $contains(widgets.model_version,"v1.4");
|
||||
$isV3OrLater := $contains(widgets.model_version,"v3.");
|
||||
$withTexture := widgets.texture or widgets.pbr;
|
||||
$isHdTexture := (widgets.texture_quality = "detailed");
|
||||
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
|
||||
$baseCredits :=
|
||||
$isV14 ? 30 : ($withTexture ? 30 : 20);
|
||||
$credits :=
|
||||
$baseCredits
|
||||
+ (widgets.quad ? 5 : 0)
|
||||
$credits := $isV14 ? 30 : (
|
||||
($withTexture ? 30 : 20)
|
||||
+ ($isHdTexture ? 10 : 0)
|
||||
+ ($isDetailedGeometry ? 20 : 0);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2)}
|
||||
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
|
||||
);
|
||||
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -457,7 +448,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
geometry_quality=geometry_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
face_limit=face_limit if face_limit != -1 else None,
|
||||
quad=quad,
|
||||
quad=None,
|
||||
),
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=80)
|
||||
@ -498,7 +489,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
expr="""
|
||||
(
|
||||
$tq := widgets.texture_quality;
|
||||
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)}
|
||||
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1), "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -555,7 +546,7 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.3}""",
|
||||
expr="""{"type":"usd","usd":0.3, "format": {"approximate": true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -592,7 +583,7 @@ class TripoRigNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.25}""",
|
||||
expr="""{"type":"usd","usd":0.25, "format": {"approximate": true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -652,7 +643,7 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
is_output_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.1}""",
|
||||
expr="""{"type":"usd","usd":0.1, "format": {"approximate": true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -761,19 +752,10 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
"face_limit",
|
||||
"texture_size",
|
||||
"texture_format",
|
||||
"force_symmetry",
|
||||
"flatten_bottom",
|
||||
"flatten_bottom_threshold",
|
||||
"pivot_to_center_bottom",
|
||||
"scale_factor",
|
||||
"with_animation",
|
||||
"pack_uv",
|
||||
"bake",
|
||||
"part_names",
|
||||
"fbx_preset",
|
||||
"export_vertex_colors",
|
||||
"export_orientation",
|
||||
"animate_in_place",
|
||||
],
|
||||
),
|
||||
expr="""
|
||||
@ -783,28 +765,16 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
$flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0;
|
||||
$scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1;
|
||||
$texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg");
|
||||
$part := widgets.part_names;
|
||||
$fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender");
|
||||
$orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default");
|
||||
$advanced :=
|
||||
widgets.quad or
|
||||
widgets.force_symmetry or
|
||||
widgets.flatten_bottom or
|
||||
widgets.pivot_to_center_bottom or
|
||||
widgets.with_animation or
|
||||
widgets.pack_uv or
|
||||
widgets.bake or
|
||||
widgets.export_vertex_colors or
|
||||
widgets.animate_in_place or
|
||||
($face != -1) or
|
||||
($texSize != 4096) or
|
||||
($flatThresh != 0) or
|
||||
($scale != 1) or
|
||||
($texFmt != "jpeg") or
|
||||
($part != "") or
|
||||
($fbx != "blender") or
|
||||
($orient != "default");
|
||||
{"type":"usd","usd": ($advanced ? 0.1 : 0.05)}
|
||||
($texFmt != "jpeg");
|
||||
{"type":"usd","usd": ($advanced ? 0.1 : 0.05), "format": {"approximate": true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]:
|
||||
"api_accessible": False,
|
||||
}
|
||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||
|
||||
# Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required
|
||||
# to correctly detect that Chinese users with working internet do have working internet.
|
||||
internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
|
||||
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
with contextlib.suppress(ClientError, OSError):
|
||||
async with session.get("https://www.google.com") as resp:
|
||||
results["internet_accessible"] = resp.status < 500
|
||||
async def _probe(url: str) -> bool:
|
||||
try:
|
||||
async with session.get(url) as resp:
|
||||
return resp.status < 500
|
||||
except (ClientError, OSError, asyncio.TimeoutError):
|
||||
return False
|
||||
|
||||
probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
|
||||
try:
|
||||
for fut in asyncio.as_completed(probe_tasks):
|
||||
if await fut:
|
||||
results["internet_accessible"] = True
|
||||
break
|
||||
finally:
|
||||
for t in probe_tasks:
|
||||
if not t.done():
|
||||
t.cancel()
|
||||
await asyncio.gather(*probe_tasks, return_exceptions=True)
|
||||
if not results["internet_accessible"]:
|
||||
return results
|
||||
|
||||
|
||||
60
comfy_extras/nodes_bg_removal.py
Normal file
60
comfy_extras/nodes_bg_removal.py
Normal file
@ -0,0 +1,60 @@
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from comfy.bg_removal_model import load
|
||||
|
||||
|
||||
class LoadBackgroundRemovalModel(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
files = folder_paths.get_filename_list("background_removal")
|
||||
return IO.Schema(
|
||||
node_id="LoadBackgroundRemovalModel",
|
||||
display_name="Load Background Removal Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
|
||||
],
|
||||
outputs=[
|
||||
IO.BackgroundRemoval.Output("bg_model")
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, bg_removal_name):
|
||||
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
|
||||
bg = load(path)
|
||||
if bg is None:
|
||||
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
|
||||
return IO.NodeOutput(bg)
|
||||
|
||||
class RemoveBackground(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RemoveBackground",
|
||||
display_name="Remove Background",
|
||||
category="image/background removal",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from"),
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mask.Output("mask", tooltip="Generated foreground mask")
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, image, bg_removal_model):
|
||||
mask = bg_removal_model.encode_image(image)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
class BackgroundRemovalExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
LoadBackgroundRemovalModel,
|
||||
RemoveBackground
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BackgroundRemovalExtension:
|
||||
return BackgroundRemovalExtension()
|
||||
@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||
batch_size = max(len(image), len(alpha))
|
||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
||||
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
|
||||
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
||||
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
||||
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
||||
|
||||
@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
|
||||
append = execute # TODO: remove
|
||||
|
||||
|
||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||
PREFERRED_KONTEXT_RESOLUTIONS = [
|
||||
(672, 1568),
|
||||
(688, 1504),
|
||||
(720, 1456),
|
||||
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
|
||||
width = image.shape[2]
|
||||
height = image.shape[1]
|
||||
aspect_ratio = width / height
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
|
||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||
return io.NodeOutput(image)
|
||||
|
||||
|
||||
@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
if bypass:
|
||||
return (latent,)
|
||||
|
||||
samples = latent["samples"]
|
||||
samples = latent["samples"].clone()
|
||||
_, height_scale_factor, width_scale_factor = (
|
||||
vae.downscale_index_formula
|
||||
)
|
||||
|
||||
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
||||
_, _, _, latent_height, latent_width = samples.shape
|
||||
width = latent_width * width_scale_factor
|
||||
height = latent_height * height_scale_factor
|
||||
|
||||
@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
|
||||
samples[:, :, :t.shape[2]] = t
|
||||
|
||||
conditioning_latent_frames_mask = torch.ones(
|
||||
(batch, 1, latent_frames, 1, 1),
|
||||
dtype=torch.float32,
|
||||
device=samples.device,
|
||||
)
|
||||
conditioning_latent_frames_mask = get_noise_mask(latent)
|
||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||
|
||||
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||
@ -236,7 +232,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
|
||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
|
||||
encode_pixels = pixels[:, :, :, :3]
|
||||
t = vae.encode(encode_pixels)
|
||||
return encode_pixels, t
|
||||
|
||||
@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
||||
|
||||
inverse_mask = torch.ones_like(mask) - mask
|
||||
|
||||
source_portion = mask * source[..., :visible_height, :visible_width]
|
||||
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
|
||||
source_rgb = source[:, :3, :visible_height, :visible_width]
|
||||
dest_slice = destination[..., top:bottom, left:right]
|
||||
|
||||
if destination.shape[1] == 4:
|
||||
if torch.max(dest_slice) == 0:
|
||||
destination[:, :3, top:bottom, left:right] = source_rgb
|
||||
destination[:, 3:4, top:bottom, left:right] = mask
|
||||
else:
|
||||
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
|
||||
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
|
||||
else:
|
||||
source_portion = mask * source_rgb
|
||||
destination_portion = inverse_mask * dest_slice
|
||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||
|
||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||
return destination
|
||||
|
||||
class LatentCompositeMasked(IO.ComfyNode):
|
||||
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
|
||||
display_name="Image Composite Masked",
|
||||
category="image",
|
||||
inputs=[
|
||||
IO.Image.Input("destination"),
|
||||
IO.Image.Input("source"),
|
||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||
IO.Boolean.Input("resize_source", default=False),
|
||||
IO.Image.Input("destination", optional=True),
|
||||
IO.Mask.Input("mask", optional=True),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
||||
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
|
||||
if destination is None: # transparent rgba
|
||||
B, H, W, C = source.shape
|
||||
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
|
||||
if C == 3:
|
||||
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
|
||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||
destination = destination.clone().movedim(-1, 1)
|
||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
|
||||
|
||||
expand_mask = execute # TODO: remove
|
||||
|
||||
|
||||
class ThresholdMask(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
|
||||
@ -63,7 +63,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
autogrow = io.Autogrow.TemplateNames(
|
||||
input=io.MultiType.Input("value", [io.Float, io.Int]),
|
||||
input=io.MultiType.Input("value", [io.Float, io.Int, io.Boolean]),
|
||||
names=list(string.ascii_lowercase),
|
||||
min=1,
|
||||
)
|
||||
@ -82,6 +82,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
outputs=[
|
||||
io.Float.Output(display_name="FLOAT"),
|
||||
io.Int.Output(display_name="INT"),
|
||||
io.Boolean.Output(display_name="BOOL"),
|
||||
],
|
||||
)
|
||||
|
||||
@ -97,7 +98,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
|
||||
result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
|
||||
# bool check must come first because bool is a subclass of int in Python
|
||||
if isinstance(result, bool) or not isinstance(result, (int, float)):
|
||||
if not isinstance(result, (int, float)):
|
||||
raise ValueError(
|
||||
f"Math Expression '{expression}' must evaluate to a numeric result, "
|
||||
f"got {type(result).__name__}: {result!r}"
|
||||
@ -106,7 +107,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
raise ValueError(
|
||||
f"Math Expression '{expression}' produced a non-finite result: {result}"
|
||||
)
|
||||
return io.NodeOutput(float(result), int(result))
|
||||
return io.NodeOutput(float(result), int(result), bool(result))
|
||||
|
||||
|
||||
class MathExtension(ComfyExtension):
|
||||
|
||||
971
comfy_extras/nodes_wandancer.py
Normal file
971
comfy_extras/nodes_wandancer.py
Normal file
@ -0,0 +1,971 @@
|
||||
import math
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import torchaudio
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import numpy as np
|
||||
import logging
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
import scipy.signal
|
||||
import scipy.ndimage
|
||||
import scipy.fft
|
||||
import scipy.sparse
|
||||
|
||||
# Audio Processing Functions - Derived from librosa (https://github.com/librosa/librosa)
|
||||
# Copyright (c) 2013--2023, librosa development team.
|
||||
|
||||
def mel_to_hz(mels, htk=False):
|
||||
"""Convert mel to Hz (slaney)"""
|
||||
mels = np.asanyarray(mels)
|
||||
if htk:
|
||||
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
freqs = f_min + f_sp * mels
|
||||
min_log_hz = 1000.0
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp
|
||||
logstep = np.log(6.4) / 27.0
|
||||
if mels.ndim:
|
||||
log_t = mels >= min_log_mel
|
||||
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
||||
elif mels >= min_log_mel:
|
||||
freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
|
||||
return freqs
|
||||
|
||||
def hz_to_mel(frequencies, htk=False):
|
||||
"""Convert Hz to mel (slaney)"""
|
||||
frequencies = np.asanyarray(frequencies)
|
||||
if htk:
|
||||
return 2595.0 * np.log10(1.0 + frequencies / 700.0)
|
||||
f_min = 0.0
|
||||
f_sp = 200.0 / 3
|
||||
mels = (frequencies - f_min) / f_sp
|
||||
min_log_hz = 1000.0
|
||||
min_log_mel = (min_log_hz - f_min) / f_sp
|
||||
logstep = np.log(6.4) / 27.0
|
||||
if frequencies.ndim:
|
||||
log_t = frequencies >= min_log_hz
|
||||
mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
|
||||
elif frequencies >= min_log_hz:
|
||||
mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
|
||||
return mels
|
||||
|
||||
def compute_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, bins_per_octave=12, tuning=0.0):
|
||||
"""Compute Constant-Q Transform (CQT) spectrogram."""
|
||||
|
||||
def _relative_bandwidth(freqs):
|
||||
bpo = np.empty_like(freqs)
|
||||
logf = np.log2(freqs)
|
||||
bpo[0] = 1.0 / (logf[1] - logf[0])
|
||||
bpo[-1] = 1.0 / (logf[-1] - logf[-2])
|
||||
bpo[1:-1] = 2.0 / (logf[2:] - logf[:-2])
|
||||
return (2.0 ** (2.0 / bpo) - 1.0) / (2.0 ** (2.0 / bpo) + 1.0)
|
||||
|
||||
def _wavelet_lengths(freqs, sr, filter_scale, alpha):
|
||||
Q = float(filter_scale) / alpha
|
||||
return Q * sr / freqs # shape (n_bins,) floats
|
||||
|
||||
def _build_wavelet(freqs_oct, sr, filter_scale, alpha_oct):
|
||||
lengths = _wavelet_lengths(freqs_oct, sr, filter_scale, alpha_oct)
|
||||
filters = []
|
||||
for ilen, freq in zip(lengths, freqs_oct):
|
||||
t = np.arange(int(-ilen // 2), int(ilen // 2), dtype=float)
|
||||
sig = (np.cos(t * 2 * np.pi * freq / sr)
|
||||
+ 1j * np.sin(t * 2 * np.pi * freq / sr)).astype(np.complex64)
|
||||
sig *= scipy.signal.get_window('hann', len(sig), fftbins=True)
|
||||
l1 = np.sum(np.abs(sig))
|
||||
tiny = np.finfo(np.float32).tiny
|
||||
sig /= max(l1, tiny)
|
||||
filters.append(sig)
|
||||
max_len = max(lengths)
|
||||
n_fft = int(2.0 ** np.ceil(np.log2(max_len)))
|
||||
out = np.zeros((len(filters), n_fft), dtype=np.complex64)
|
||||
for k, f in enumerate(filters):
|
||||
lpad = int((n_fft - len(f)) // 2)
|
||||
out[k, lpad: lpad + len(f)] = f
|
||||
return out, lengths
|
||||
|
||||
def _resample_half(y):
|
||||
ratio = 0.5
|
||||
n_samples = int(np.ceil(len(y) * ratio))
|
||||
# Kaiser-windowed FIR matches librosa/soxr more closely than scipy's default Hamming filter
|
||||
L = 2
|
||||
h = scipy.signal.firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
|
||||
y_hat = scipy.signal.resample_poly(y.astype(np.float32), 1, 2, window=h)
|
||||
if len(y_hat) > n_samples:
|
||||
y_hat = y_hat[:n_samples]
|
||||
elif len(y_hat) < n_samples:
|
||||
y_hat = np.pad(y_hat, (0, n_samples - len(y_hat)))
|
||||
y_hat /= np.sqrt(ratio)
|
||||
return y_hat.astype(np.float32)
|
||||
|
||||
def _sparsify_rows(x, quantile=0.01):
|
||||
mags = np.abs(x)
|
||||
norms = np.sum(mags, axis=1, keepdims=True)
|
||||
norms = np.where(norms == 0, 1.0, norms)
|
||||
mag_sort = np.sort(mags, axis=1)
|
||||
cumulative_mag = np.cumsum(mag_sort / norms, axis=1)
|
||||
threshold_idx = np.argmin(cumulative_mag < quantile, axis=1)
|
||||
x_sparse = scipy.sparse.lil_matrix(x.shape, dtype=x.dtype)
|
||||
for i, j in enumerate(threshold_idx):
|
||||
idx = np.where(mags[i] >= mag_sort[i, j])
|
||||
x_sparse[i, idx] = x[i, idx]
|
||||
return x_sparse.tocsr()
|
||||
|
||||
if fmin is None:
|
||||
fmin = 32.70319566257483 # C1 note frequency
|
||||
|
||||
fmin = fmin * (2.0 ** (tuning / bins_per_octave))
|
||||
freqs = fmin * (2.0 ** (np.arange(n_bins) / bins_per_octave))
|
||||
|
||||
alpha = _relative_bandwidth(freqs)
|
||||
lengths = _wavelet_lengths(freqs, float(sr), 1, alpha)
|
||||
|
||||
n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
|
||||
n_filters = min(bins_per_octave, n_bins)
|
||||
|
||||
cqt_resp = []
|
||||
my_y = y.astype(np.float32)
|
||||
my_sr = float(sr)
|
||||
my_hop = int(hop_length)
|
||||
|
||||
for i in range(n_octaves):
|
||||
if i == 0:
|
||||
sl = slice(-n_filters, None)
|
||||
else:
|
||||
sl = slice(-n_filters * (i + 1), -n_filters * i)
|
||||
|
||||
freqs_oct = freqs[sl]
|
||||
alpha_oct = alpha[sl]
|
||||
|
||||
basis, basis_lengths = _build_wavelet(freqs_oct, my_sr, 1, alpha_oct)
|
||||
n_fft_oct = basis.shape[1]
|
||||
|
||||
# Frequency-domain normalisation
|
||||
basis = basis.astype(np.complex64)
|
||||
basis *= basis_lengths[:, np.newaxis] / float(n_fft_oct)
|
||||
fft_basis = scipy.fft.fft(basis, n=n_fft_oct, axis=1)[:, :(n_fft_oct // 2) + 1]
|
||||
fft_basis = _sparsify_rows(fft_basis, quantile=0.01)
|
||||
fft_basis = fft_basis * np.sqrt(sr / my_sr)
|
||||
|
||||
y_pad = np.pad(my_y, int(n_fft_oct // 2), mode='constant')
|
||||
n_frames = 1 + (len(y_pad) - n_fft_oct) // my_hop
|
||||
frames = np.lib.stride_tricks.as_strided(
|
||||
y_pad,
|
||||
shape=(n_fft_oct, n_frames),
|
||||
strides=(y_pad.strides[0], y_pad.strides[0] * my_hop),
|
||||
)
|
||||
stft_result = scipy.fft.rfft(frames, axis=0)
|
||||
cqt_resp.append(fft_basis.dot(stft_result))
|
||||
|
||||
if my_hop % 2 == 0:
|
||||
my_hop //= 2
|
||||
my_sr /= 2.0
|
||||
my_y = _resample_half(my_y)
|
||||
|
||||
max_col = min(c.shape[-1] for c in cqt_resp)
|
||||
cqt_out = np.empty((n_bins, max_col), dtype=np.complex64)
|
||||
end = n_bins
|
||||
for c_i in cqt_resp:
|
||||
n_oct = c_i.shape[0]
|
||||
if end < n_oct:
|
||||
cqt_out[:end, :] = c_i[-end:, :max_col]
|
||||
else:
|
||||
cqt_out[end - n_oct:end, :] = c_i[:, :max_col]
|
||||
end -= n_oct
|
||||
|
||||
cqt_out /= np.sqrt(lengths)[:, np.newaxis]
|
||||
return np.abs(cqt_out).astype(np.float32)
|
||||
|
||||
|
||||
def cq_to_chroma_mapping(n_input, bins_per_octave=12, n_chroma=12, fmin=None):
|
||||
"""Map CQT bins to chroma bins."""
|
||||
|
||||
if fmin is None:
|
||||
fmin = 32.70319566257483 # C1 note frequency
|
||||
|
||||
n_merge = bins_per_octave / n_chroma
|
||||
cq_to_ch = np.repeat(np.eye(n_chroma), int(n_merge), axis=1)
|
||||
cq_to_ch = np.roll(cq_to_ch, -int(n_merge // 2), axis=1)
|
||||
n_octaves = int(np.ceil(n_input / bins_per_octave))
|
||||
cq_to_ch = np.tile(cq_to_ch, n_octaves)[:, :n_input]
|
||||
|
||||
midi_0 = np.mod(12 * np.log2(fmin / 440.0) + 69, 12)
|
||||
roll = int(np.round(midi_0 * (n_chroma / 12.0)))
|
||||
cq_to_ch = np.roll(cq_to_ch, roll, axis=0)
|
||||
|
||||
return cq_to_ch.astype(np.float32)
|
||||
|
||||
|
||||
def _parabolic_interpolation(S, axis=-2):
|
||||
"""Compute parabolic interpolation shift for peak refinement."""
|
||||
S_next = np.roll(S, -1, axis=axis)
|
||||
S_prev = np.roll(S, 1, axis=axis)
|
||||
|
||||
a = S_next + S_prev - 2 * S
|
||||
b = (S_next - S_prev) / 2.0
|
||||
|
||||
shifts = np.zeros_like(S)
|
||||
valid = np.abs(b) < np.abs(a)
|
||||
shifts[valid] = -b[valid] / a[valid]
|
||||
|
||||
if axis == -2 or axis == S.ndim - 2:
|
||||
shifts[0, :] = 0
|
||||
shifts[-1, :] = 0
|
||||
elif axis == 0:
|
||||
shifts[0, ...] = 0
|
||||
shifts[-1, ...] = 0
|
||||
|
||||
return shifts
|
||||
|
||||
|
||||
def _localmax(S, axis=-2):
|
||||
"""Find local maxima along an axis."""
|
||||
|
||||
S_prev = np.roll(S, 1, axis=axis)
|
||||
S_next = np.roll(S, -1, axis=axis)
|
||||
|
||||
local_max = (S > S_prev) & (S >= S_next)
|
||||
|
||||
if axis == -2 or axis == S.ndim - 2:
|
||||
local_max[-1, :] = S[-1, :] > S[-2, :]
|
||||
# First element is never a local max (strict inequality with previous)
|
||||
local_max[0, :] = False
|
||||
elif axis == 0:
|
||||
local_max[-1, ...] = S[-1, ...] > S[-2, ...]
|
||||
local_max[0, ...] = False
|
||||
|
||||
return local_max
|
||||
|
||||
|
||||
def piptrack(y=None, sr=22050, S=None, n_fft=2048, hop_length=512,
|
||||
fmin=150.0, fmax=4000.0, threshold=0.1):
|
||||
"""Pitch tracking on thresholded parabolically-interpolated STFT."""
|
||||
|
||||
# Compute STFT if not provided
|
||||
if S is None:
|
||||
if y is None:
|
||||
raise ValueError("Either y or S must be provided")
|
||||
|
||||
fft_window = scipy.signal.get_window('hann', n_fft, fftbins=True)
|
||||
if len(fft_window) < n_fft:
|
||||
lpad = int((n_fft - len(fft_window)) // 2)
|
||||
fft_window = np.pad(fft_window, (lpad, int(n_fft - len(fft_window) - lpad)), mode='constant')
|
||||
fft_window = fft_window.reshape((-1, 1))
|
||||
|
||||
y_pad = np.pad(y, int(n_fft // 2), mode='constant')
|
||||
n_frames = 1 + (len(y_pad) - n_fft) // hop_length
|
||||
frames = np.lib.stride_tricks.as_strided(
|
||||
y_pad,
|
||||
shape=(n_fft, n_frames),
|
||||
strides=(y_pad.strides[0], y_pad.strides[0] * hop_length)
|
||||
)
|
||||
|
||||
S = scipy.fft.rfft((fft_window * frames).astype(np.float32), axis=0)
|
||||
|
||||
S = np.abs(S)
|
||||
|
||||
fmin = max(fmin, 0)
|
||||
fmax = min(fmax, float(sr) / 2)
|
||||
|
||||
fft_freqs = np.fft.rfftfreq(S.shape[0] * 2 - 2, 1.0 / sr)
|
||||
if len(fft_freqs) > S.shape[0]:
|
||||
fft_freqs = fft_freqs[:S.shape[0]]
|
||||
|
||||
shift = _parabolic_interpolation(S, axis=0)
|
||||
avg = np.gradient(S, axis=0)
|
||||
dskew = 0.5 * avg * shift
|
||||
|
||||
pitches = np.zeros_like(S)
|
||||
mags = np.zeros_like(S)
|
||||
|
||||
freq_mask = (fmin <= fft_freqs) & (fft_freqs < fmax)
|
||||
freq_mask = freq_mask.reshape(-1, 1)
|
||||
|
||||
ref_value = threshold * np.max(S, axis=0, keepdims=True)
|
||||
local_max = _localmax(S * (S > ref_value), axis=0)
|
||||
idx = np.nonzero(freq_mask & local_max)
|
||||
|
||||
pitches[idx] = (idx[0] + shift[idx]) * float(sr) / (S.shape[0] * 2 - 2)
|
||||
mags[idx] = S[idx] + dskew[idx]
|
||||
|
||||
return pitches, mags
|
||||
|
||||
|
||||
def hz_to_octs(frequencies, tuning=0.0, bins_per_octave=12):
|
||||
"""Convert frequencies (Hz) to octave numbers."""
|
||||
|
||||
A440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
|
||||
octs = np.log2(np.asanyarray(frequencies) / (float(A440) / 16))
|
||||
return octs
|
||||
|
||||
|
||||
def pitch_tuning(frequencies, resolution=0.01, bins_per_octave=12):
|
||||
"""Estimate tuning offset from a collection of pitches."""
|
||||
|
||||
frequencies = np.atleast_1d(frequencies)
|
||||
frequencies = frequencies[frequencies > 0]
|
||||
|
||||
if not np.any(frequencies):
|
||||
return 0.0
|
||||
|
||||
residual = np.mod(bins_per_octave * hz_to_octs(frequencies, tuning=0.0,
|
||||
bins_per_octave=bins_per_octave), 1.0)
|
||||
residual[residual >= 0.5] -= 1.0
|
||||
|
||||
bins = np.linspace(-0.5, 0.5, int(np.ceil(1.0 / resolution)) + 1)
|
||||
counts, tuning = np.histogram(residual, bins)
|
||||
tuning_est = tuning[np.argmax(counts)]
|
||||
return tuning_est
|
||||
|
||||
|
||||
def estimate_tuning(y, sr=22050, bins_per_octave=12):
|
||||
"""Estimate global tuning deviation from 12-TET."""
|
||||
n_fft = 2048
|
||||
hop_length = 512
|
||||
|
||||
if len(y) < n_fft:
|
||||
return 0.0
|
||||
|
||||
pitch, mag = piptrack(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length,
|
||||
fmin=150.0, fmax=4000.0, threshold=0.1)
|
||||
|
||||
pitch_mask = pitch > 0
|
||||
|
||||
if not pitch_mask.any():
|
||||
return 0.0
|
||||
|
||||
threshold = np.median(mag[pitch_mask])
|
||||
valid_pitches = pitch[(mag >= threshold) & pitch_mask]
|
||||
|
||||
if len(valid_pitches) == 0:
|
||||
return 0.0
|
||||
|
||||
tuning = pitch_tuning(valid_pitches, resolution=0.01, bins_per_octave=bins_per_octave)
|
||||
|
||||
return float(tuning)
|
||||
|
||||
|
||||
def compute_chroma_cens(y, sr=22050, hop_length=512, n_chroma=12,
|
||||
n_octaves=7, bins_per_octave=36,
|
||||
win_len_smooth=41, norm=2):
|
||||
"""Compute Chroma Energy Normalized Statistics (CENS) features."""
|
||||
|
||||
tuning = estimate_tuning(y, sr, bins_per_octave=bins_per_octave)
|
||||
|
||||
fmin = 32.70319566257483 # C1 note frequency
|
||||
n_bins = n_octaves * bins_per_octave
|
||||
cqt_mag = compute_cqt(y, sr=sr, hop_length=hop_length,
|
||||
fmin=fmin, n_bins=n_bins,
|
||||
bins_per_octave=bins_per_octave,
|
||||
tuning=tuning)
|
||||
|
||||
chroma_map = cq_to_chroma_mapping(n_bins, bins_per_octave=bins_per_octave,
|
||||
n_chroma=n_chroma, fmin=fmin)
|
||||
chroma = np.dot(chroma_map, cqt_mag)
|
||||
|
||||
threshold = np.finfo(chroma.dtype).tiny
|
||||
chroma_sum = np.sum(np.abs(chroma), axis=0, keepdims=True)
|
||||
chroma_sum = np.maximum(chroma_sum, threshold)
|
||||
chroma = chroma / chroma_sum
|
||||
|
||||
quant_steps = [0.4, 0.2, 0.1, 0.05]
|
||||
quant_weights = [0.25, 0.25, 0.25, 0.25]
|
||||
chroma_quant = np.zeros_like(chroma)
|
||||
for step, weight in zip(quant_steps, quant_weights):
|
||||
chroma_quant += (chroma > step) * weight
|
||||
|
||||
if win_len_smooth is not None and win_len_smooth > 0:
|
||||
win = scipy.signal.get_window('hann', win_len_smooth + 2, fftbins=False)
|
||||
win /= np.sum(win)
|
||||
win = win.reshape(1, -1)
|
||||
chroma_smooth = scipy.ndimage.convolve(chroma_quant, win, mode='constant')
|
||||
else:
|
||||
chroma_smooth = chroma_quant
|
||||
|
||||
if norm == 2:
|
||||
threshold = np.finfo(chroma_smooth.dtype).tiny
|
||||
chroma_norm = np.sqrt(np.sum(chroma_smooth ** 2, axis=0, keepdims=True))
|
||||
chroma_norm = np.maximum(chroma_norm, threshold)
|
||||
chroma_smooth = chroma_smooth / chroma_norm
|
||||
elif norm == np.inf:
|
||||
threshold = np.finfo(chroma_smooth.dtype).tiny
|
||||
chroma_norm = np.max(np.abs(chroma_smooth), axis=0, keepdims=True)
|
||||
chroma_norm = np.maximum(chroma_norm, threshold)
|
||||
chroma_smooth = chroma_smooth / chroma_norm
|
||||
|
||||
return chroma_smooth
|
||||
|
||||
|
||||
def _create_mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None):
|
||||
"""Create mel-scale filterbank matrix."""
|
||||
if fmax is None:
|
||||
fmax = sr / 2.0
|
||||
mel_basis = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=np.float32)
|
||||
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
||||
min_mel = hz_to_mel(fmin)
|
||||
max_mel = hz_to_mel(fmax)
|
||||
mels = np.linspace(min_mel, max_mel, n_mels + 2)
|
||||
mel_f = mel_to_hz(mels)
|
||||
fdiff = np.diff(mel_f)
|
||||
ramps = np.subtract.outer(mel_f, fftfreqs)
|
||||
|
||||
for i in range(n_mels):
|
||||
lower = -ramps[i] / fdiff[i]
|
||||
upper = ramps[i + 2] / fdiff[i + 1]
|
||||
mel_basis[i] = np.maximum(0, np.minimum(lower, upper))
|
||||
|
||||
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
|
||||
mel_basis *= enorm[:, np.newaxis]
|
||||
return mel_basis
|
||||
|
||||
|
||||
def _compute_mel_spectrogram(data, sr, n_fft=2048, hop_length=512, n_mels=128):
|
||||
"""Compute mel spectrogram from audio signal."""
|
||||
fft_window = scipy.signal.get_window('hann', n_fft, fftbins=True)
|
||||
if len(fft_window) < n_fft:
|
||||
lpad = int((n_fft - len(fft_window)) // 2)
|
||||
fft_window = np.pad(fft_window, (lpad, int(n_fft - len(fft_window) - lpad)), mode='constant')
|
||||
|
||||
fft_window = fft_window.reshape((-1, 1))
|
||||
data_padded = np.pad(data, int(n_fft // 2), mode='constant')
|
||||
n_frames = 1 + (len(data_padded) - n_fft) // hop_length
|
||||
shape = (n_fft, n_frames)
|
||||
strides = (data_padded.strides[0], data_padded.strides[0] * hop_length)
|
||||
frames = np.lib.stride_tricks.as_strided(data_padded, shape=shape, strides=strides)
|
||||
|
||||
stft_result = scipy.fft.rfft(fft_window * frames, axis=0).astype(np.complex64)
|
||||
power_spec = np.abs(stft_result) ** 2
|
||||
|
||||
mel_basis = _create_mel_filterbank(sr, n_fft, n_mels=n_mels, fmin=0.0, fmax=sr / 2.0)
|
||||
mel_spec = np.dot(mel_basis, power_spec)
|
||||
return mel_spec.astype(np.float32)
|
||||
|
||||
|
||||
def quick_tempo_estimate(audio_np, sr, start_bpm=120.0, std_bpm=1.0, hop_length=512):
|
||||
"""Estimate tempo using autocorrelation tempogram."""
|
||||
|
||||
if len(audio_np) < hop_length * 10:
|
||||
logging.warning("Audio too short for tempo estimation, returning default BPM of 120.0")
|
||||
return 120.0
|
||||
|
||||
n_fft = 2048
|
||||
mel_S = _compute_mel_spectrogram(audio_np, sr, n_fft=n_fft, hop_length=hop_length, n_mels=128)
|
||||
log_mel_S = 10.0 * np.log10(np.maximum(1e-10, mel_S))
|
||||
|
||||
lag = 1
|
||||
S_diff = log_mel_S[:, lag:] - log_mel_S[:, :-lag]
|
||||
S_onset = np.maximum(0.0, S_diff)
|
||||
onset_env_pre = np.mean(S_onset, axis=0)
|
||||
pad_width = lag + n_fft // (2 * hop_length)
|
||||
onset_env = np.pad(onset_env_pre, (pad_width, 0), mode='constant')
|
||||
onset_env = onset_env[:mel_S.shape[1]]
|
||||
|
||||
return estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm, std_bpm, max_tempo=320.0)
|
||||
|
||||
|
||||
def estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm=120.0, std_bpm=1.0, max_tempo=320.0):
|
||||
"""Estimate tempo from onset strength envelope using autocorrelation tempogram."""
|
||||
if len(onset_env) < 20:
|
||||
return 120.0
|
||||
|
||||
ac_size = 8.0
|
||||
win_length = int(np.round(ac_size * sr / hop_length))
|
||||
win_length = min(win_length, len(onset_env))
|
||||
|
||||
pad_width = win_length // 2
|
||||
onset_padded = np.pad(onset_env, (pad_width, pad_width), mode='linear_ramp', end_values=(0, 0))
|
||||
|
||||
n_frames = len(onset_env)
|
||||
shape = (win_length, n_frames)
|
||||
strides = (onset_padded.strides[0], onset_padded.strides[0])
|
||||
frames = np.lib.stride_tricks.as_strided(onset_padded, shape=shape, strides=strides)
|
||||
|
||||
hann_window = scipy.signal.get_window('hann', win_length, fftbins=True)
|
||||
windowed_frames = frames * hann_window[:, np.newaxis]
|
||||
|
||||
tempogram = np.zeros((win_length, n_frames))
|
||||
for i in range(n_frames):
|
||||
frame = windowed_frames[:, i]
|
||||
n_pad = scipy.fft.next_fast_len(2 * len(frame) - 1)
|
||||
fft_result = scipy.fft.rfft(frame, n=n_pad)
|
||||
powspec = np.abs(fft_result) ** 2
|
||||
ac = scipy.fft.irfft(powspec, n=n_pad)
|
||||
tempogram[:, i] = ac[:win_length]
|
||||
|
||||
ac_max = np.max(np.abs(tempogram), axis=0)
|
||||
mask = ac_max > 0
|
||||
tempogram[:, mask] /= ac_max[mask]
|
||||
|
||||
tempogram_mean = np.mean(tempogram, axis=1)
|
||||
tempogram_mean = np.maximum(tempogram_mean, 0)
|
||||
|
||||
bpms = np.zeros(win_length, dtype=np.float64)
|
||||
bpms[0] = np.inf
|
||||
bpms[1:] = 60.0 * sr / (hop_length * np.arange(1.0, win_length))
|
||||
|
||||
logprior = -0.5 * ((np.log2(bpms) - np.log2(start_bpm)) / std_bpm) ** 2
|
||||
|
||||
if max_tempo is not None:
|
||||
max_idx = int(np.argmax(bpms < max_tempo))
|
||||
if max_idx > 0:
|
||||
logprior[:max_idx] = -np.inf
|
||||
|
||||
weighted = np.log1p(1e6 * tempogram_mean) + logprior
|
||||
best_idx = int(np.argmax(weighted[1:])) + 1
|
||||
tempo = bpms[best_idx]
|
||||
|
||||
return tempo
|
||||
|
||||
|
||||
def detect_onset_peaks(onset_env, sr=22050, hop_length=512, pre_max=0.03, post_max=0.0,
|
||||
pre_avg=0.10, post_avg=0.10, wait=0.03, delta=0.07):
|
||||
"""Detect onset peaks using peak picking algorithm."""
|
||||
|
||||
onset_normalized = onset_env - np.min(onset_env)
|
||||
onset_max = np.max(onset_normalized)
|
||||
if onset_max > 0:
|
||||
onset_normalized = onset_normalized / onset_max
|
||||
|
||||
pre_max_frames = int(pre_max * sr / hop_length)
|
||||
post_max_frames = int(post_max * sr / hop_length) + 1
|
||||
pre_avg_frames = int(pre_avg * sr / hop_length)
|
||||
post_avg_frames = int(post_avg * sr / hop_length) + 1
|
||||
wait_frames = int(wait * sr / hop_length)
|
||||
|
||||
peaks = np.zeros(len(onset_normalized), dtype=bool)
|
||||
peaks[0] = (onset_normalized[0] >= np.max(onset_normalized[:min(post_max_frames, len(onset_normalized))]))
|
||||
peaks[0] &= (onset_normalized[0] >= np.mean(onset_normalized[:min(post_avg_frames, len(onset_normalized))]) + delta)
|
||||
|
||||
if peaks[0]:
|
||||
n = wait_frames + 1
|
||||
else:
|
||||
n = 1
|
||||
|
||||
while n < len(onset_normalized):
|
||||
maxn = np.max(onset_normalized[max(0, n - pre_max_frames):min(n + post_max_frames, len(onset_normalized))])
|
||||
peaks[n] = (onset_normalized[n] == maxn)
|
||||
|
||||
if not peaks[n]:
|
||||
n += 1
|
||||
continue
|
||||
|
||||
avgn = np.mean(onset_normalized[max(0, n - pre_avg_frames):min(n + post_avg_frames, len(onset_normalized))])
|
||||
peaks[n] &= (onset_normalized[n] >= avgn + delta)
|
||||
|
||||
if not peaks[n]:
|
||||
n += 1
|
||||
continue
|
||||
|
||||
n += wait_frames + 1
|
||||
|
||||
return np.flatnonzero(peaks).astype(np.int32)
|
||||
|
||||
|
||||
def track_beats(onset_env, tempo, sr, hop_length, tightness=100, trim=True):
|
||||
"""Track beats using dynamic programming."""
|
||||
|
||||
frame_rate = sr / hop_length
|
||||
frames_per_beat = np.round(frame_rate * 60.0 / tempo)
|
||||
|
||||
if frames_per_beat <= 0 or len(onset_env) < 2:
|
||||
return np.array([], dtype=np.int32)
|
||||
|
||||
onset_std = np.std(onset_env, ddof=1)
|
||||
if onset_std > 0:
|
||||
onset_normalized = onset_env / onset_std
|
||||
else:
|
||||
onset_normalized = onset_env
|
||||
|
||||
window_range = np.arange(-frames_per_beat, frames_per_beat + 1)
|
||||
window = np.exp(-0.5 * (window_range * 32.0 / frames_per_beat) ** 2)
|
||||
|
||||
localscore = scipy.signal.convolve(onset_normalized, window, mode='same')
|
||||
|
||||
backlink = np.full(len(localscore), -1, dtype=np.int32)
|
||||
cumscore = np.zeros(len(localscore), dtype=np.float64)
|
||||
|
||||
score_thresh = 0.01 * localscore.max()
|
||||
first_beat = True
|
||||
|
||||
backlink[0] = -1
|
||||
cumscore[0] = localscore[0]
|
||||
|
||||
fpb = int(frames_per_beat)
|
||||
|
||||
for i in range(1, len(localscore)):
|
||||
score_i = localscore[i]
|
||||
best_score = -np.inf
|
||||
beat_location = -1
|
||||
|
||||
search_start = int(i - np.round(fpb / 2.0))
|
||||
search_end = int(i - 2 * fpb - 1)
|
||||
|
||||
for loc in range(search_start, search_end, -1):
|
||||
if loc < 0:
|
||||
break
|
||||
|
||||
score = cumscore[loc] - tightness * (np.log(i - loc) - np.log(fpb)) ** 2
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
beat_location = loc
|
||||
|
||||
if beat_location >= 0:
|
||||
cumscore[i] = score_i + best_score
|
||||
else:
|
||||
cumscore[i] = score_i
|
||||
|
||||
if first_beat and score_i < score_thresh:
|
||||
backlink[i] = -1
|
||||
else:
|
||||
backlink[i] = beat_location
|
||||
first_beat = False
|
||||
|
||||
local_max_mask = np.zeros(len(cumscore), dtype=bool)
|
||||
|
||||
local_max_mask[0] = False
|
||||
|
||||
for i in range(1, len(cumscore) - 1):
|
||||
local_max_mask[i] = (cumscore[i] > cumscore[i-1]) and (cumscore[i] >= cumscore[i+1])
|
||||
|
||||
if len(cumscore) > 1:
|
||||
local_max_mask[-1] = cumscore[-1] > cumscore[-2]
|
||||
|
||||
if np.any(local_max_mask):
|
||||
median_max = np.median(cumscore[local_max_mask])
|
||||
threshold = 0.5 * median_max
|
||||
|
||||
tail = -1
|
||||
for i in range(len(cumscore) - 1, -1, -1):
|
||||
if local_max_mask[i] and cumscore[i] >= threshold:
|
||||
tail = i
|
||||
break
|
||||
else:
|
||||
tail = len(cumscore) - 1
|
||||
|
||||
beats = np.zeros(len(localscore), dtype=bool)
|
||||
n = tail
|
||||
visited = set()
|
||||
while n >= 0 and n not in visited:
|
||||
beats[n] = True
|
||||
visited.add(n)
|
||||
n = backlink[n]
|
||||
|
||||
if trim and np.any(beats):
|
||||
beat_positions = np.flatnonzero(beats)
|
||||
|
||||
beat_localscores = localscore[beat_positions]
|
||||
|
||||
w = np.hanning(5)
|
||||
smooth_boe_full = np.convolve(beat_localscores, w)
|
||||
smooth_boe = smooth_boe_full[len(w)//2 : len(localscore) + len(w)//2]
|
||||
|
||||
threshold = 0.5 * np.sqrt(np.mean(smooth_boe ** 2))
|
||||
|
||||
start_frame = 0
|
||||
while start_frame < len(localscore) and localscore[start_frame] <= threshold:
|
||||
beats[start_frame] = False
|
||||
start_frame += 1
|
||||
|
||||
end_frame = len(localscore) - 1
|
||||
while end_frame >= 0 and localscore[end_frame] <= threshold:
|
||||
beats[end_frame] = False
|
||||
end_frame -= 1
|
||||
|
||||
return np.flatnonzero(beats).astype(np.int32)
|
||||
|
||||
def compute_onset_envelope(mel_spec_db, n_fft=2048, hop_length=512):
|
||||
"""Compute onset strength envelope from a log-mel spectrogram (dB)."""
|
||||
lag = 1
|
||||
onset_diff = mel_spec_db[:, lag:] - mel_spec_db[:, :-lag]
|
||||
onset_diff = np.maximum(0.0, onset_diff)
|
||||
envelope_pre_pad = np.mean(onset_diff, axis=0)
|
||||
|
||||
pad_width = lag + n_fft // (2 * hop_length)
|
||||
envelope = np.pad(envelope_pre_pad, (pad_width, 0), mode='constant')
|
||||
envelope = envelope[:mel_spec_db.shape[1]]
|
||||
|
||||
return envelope
|
||||
|
||||
def compute_mfcc(mel_spec_db, n_mfcc=20):
|
||||
"""Compute MFCC features from a log-mel spectrogram (dB)."""
|
||||
mfcc = scipy.fft.dct(mel_spec_db, axis=0, type=2, norm='ortho')[:n_mfcc].T
|
||||
return mfcc.astype(np.float32)
|
||||
|
||||
|
||||
def power_to_db(S, amin=1e-10, top_db=80.0, ref=1.0):
|
||||
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units"""
|
||||
S = np.asarray(S)
|
||||
log_spec = 10.0 * np.log10(np.maximum(amin, S))
|
||||
log_spec -= 10.0 * np.log10(np.maximum(amin, ref))
|
||||
if top_db is not None:
|
||||
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
|
||||
return log_spec
|
||||
|
||||
|
||||
class WanDancerEncodeAudio(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanDancerEncodeAudio",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Audio.Input("audio"),
|
||||
io.Int.Input("video_frames", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Float.Input("audio_inject_scale", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="The scale for the audio features when injected into the video model."),
|
||||
],
|
||||
outputs=[
|
||||
io.AudioEncoderOutput.Output(display_name="audio_encoder_output"),
|
||||
io.String.Output(display_name="fps_string", tooltip="The calculated fps based on the audio length and the number of video frames. Used in the prompt."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video_frames, audio_inject_scale, audio) -> io.NodeOutput:
|
||||
waveform = audio["waveform"][0]
|
||||
sample_rate = audio["sample_rate"]
|
||||
base_fps = 30
|
||||
hop_length = 512
|
||||
model_sr = 22050
|
||||
n_fft = 2048
|
||||
|
||||
# start tempo from original audio (not the resampled one) to match the reference pipeline
|
||||
if waveform.shape[0] > 1:
|
||||
waveform = waveform.mean(dim=0, keepdim=False)
|
||||
|
||||
start_bpm = quick_tempo_estimate(waveform.squeeze().cpu().numpy(), sample_rate, hop_length=hop_length)
|
||||
|
||||
# resample to the sample rate used for feature extraction
|
||||
resample_sr = base_fps * hop_length
|
||||
waveform = torchaudio.functional.resample(waveform, sample_rate, resample_sr)
|
||||
|
||||
waveform_np = waveform.cpu().numpy().squeeze()
|
||||
mel_spec = _compute_mel_spectrogram(waveform_np, model_sr, n_fft, hop_length, n_mels=128)
|
||||
mel_spec_db = power_to_db(mel_spec, amin=1e-10, top_db=80.0, ref=1.0)
|
||||
envelope = compute_onset_envelope(mel_spec_db, n_fft, hop_length)
|
||||
mfcc = compute_mfcc(mel_spec_db, n_mfcc=20)
|
||||
chroma = compute_chroma_cens(y=waveform_np, sr=model_sr, hop_length=hop_length).T
|
||||
# detect peaks
|
||||
peak_idxs = detect_onset_peaks(envelope, sr=model_sr, hop_length=hop_length)
|
||||
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
|
||||
peak_onehot[peak_idxs] = 1.0
|
||||
# detect beats
|
||||
beat_tracking_tempo = estimate_tempo_from_onset(envelope, sr=model_sr, hop_length=hop_length, start_bpm=start_bpm)
|
||||
beat_idxs = track_beats(envelope, beat_tracking_tempo, model_sr, hop_length, tightness=100, trim=True)
|
||||
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
|
||||
beat_onehot[beat_idxs] = 1.0
|
||||
|
||||
audio_feature = np.concatenate(
|
||||
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
|
||||
axis=-1,
|
||||
)
|
||||
audio_feature = torch.from_numpy(audio_feature).unsqueeze(0).to(comfy.model_management.intermediate_device())
|
||||
|
||||
fps = float(base_fps / int(audio_feature.shape[1] / video_frames + 0.5))
|
||||
|
||||
audio_encoder_output = {
|
||||
"audio_feature": audio_feature,
|
||||
"fps": fps,
|
||||
"audio_inject_scale": audio_inject_scale,
|
||||
}
|
||||
|
||||
if int(fps + 0.5) != 30:
|
||||
fps_string = " 帧率是{:.4f}".format(fps) # "frame rate is" in Chinese, as it was in the original pipeline
|
||||
else:
|
||||
fps_string = ", 帧率是30fps。" # to match the reference pipeline when the fps is 30
|
||||
|
||||
return io.NodeOutput(audio_encoder_output, fps_string)
|
||||
|
||||
|
||||
class WanDancerVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanDancerVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="The number of frames in the generated video. Should stay 149 for WanDancer."),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="The CLIP vision embeds for the first frame."),
|
||||
io.ClipVisionOutput.Input("clip_vision_output_ref", optional=True, tooltip="The CLIP vision embeds for the reference image."),
|
||||
io.Image.Input("start_image", optional=True, tooltip="The initial image(s) to be encoded, can be any number of frames."),
|
||||
io.Mask.Input("mask", optional=True, tooltip="Image conditioning mask for the start image(s). White is kept, black is generated. Used for the local generations."),
|
||||
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty latent."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, start_image=None, mask=None, clip_vision_output=None, clip_vision_output_ref=None, audio_encoder_output=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
image = torch.zeros((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
image[:start_image.shape[0]] = start_image
|
||||
|
||||
concat_latent_image = vae.encode(image[:, :, :, :3])
|
||||
if mask is None:
|
||||
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
else:
|
||||
concat_mask = 1 - mask[:length].unsqueeze(0)
|
||||
concat_mask = comfy.utils.common_upscale(concat_mask, concat_latent_image.shape[-2], concat_latent_image.shape[-1], "nearest-exact", "disabled")
|
||||
concat_mask = torch.cat([torch.repeat_interleave(concat_mask[:, 0:1], repeats=4, dim=1), concat_mask[:, 1:]], dim=1)
|
||||
concat_mask = concat_mask.view(1, concat_mask.shape[1] // 4, 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]).transpose(1, 2)
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})
|
||||
|
||||
if audio_encoder_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanDancerPadKeyframes(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanDancerPadKeyframes",
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Image.Input("images",),
|
||||
io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of this segment (usually 149 frames)"),
|
||||
io.Int.Input("segment_index", default=0, min=0, max=100, tooltip="Which segment this is (0 for first, 1 for second, etc.)"),
|
||||
io.Audio.Input("audio", tooltip="Audio to calculate total output frames from and extract segment audio."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequence"),
|
||||
io.Mask.Output(display_name="keyframes_mask", tooltip="Mask indicating valid frames"),
|
||||
io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for this video segment"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def do_execute(cls, images, segment_length, segment_index, audio):
|
||||
B, H, W, C = images.shape
|
||||
fps = 30
|
||||
|
||||
# calculate total frames
|
||||
audio_duration = audio["waveform"].shape[-1] / audio["sample_rate"]
|
||||
segment_duration = segment_length / fps
|
||||
buffer = 0.2
|
||||
num_segments = int((audio_duration - buffer) / segment_duration) + 1 if audio_duration > buffer else 0
|
||||
total_frames = num_segments * segment_length
|
||||
|
||||
mask = torch.zeros((segment_length, H, W), device=images.device, dtype=images.dtype)
|
||||
keyframes = torch.zeros((segment_length, H, W, C), dtype=images.dtype, device=images.device)
|
||||
|
||||
# guard: with no audio or no images, nothing to place — leave keyframes/mask zeroed
|
||||
if total_frames > 0 and B > 0:
|
||||
frame_interval = float(total_frames) / B
|
||||
seg_num = int(math.ceil(total_frames / segment_length))
|
||||
is_last_segment = (segment_index == seg_num - 1)
|
||||
|
||||
positions = []
|
||||
images_before_this_segment = 0
|
||||
|
||||
# count images consumed by previous segments
|
||||
for seg_idx in range(segment_index):
|
||||
end_idx = (total_frames - segment_length * seg_idx - 1) if seg_idx == seg_num - 1 else (segment_length - 1)
|
||||
cnt = 0
|
||||
while cnt * frame_interval < end_idx - frame_interval:
|
||||
cnt += 1
|
||||
images_before_this_segment += cnt
|
||||
|
||||
# positions for current segment
|
||||
end_index = (total_frames - segment_length * segment_index - 1) if is_last_segment else (segment_length - 1)
|
||||
cnt = 0
|
||||
while cnt * frame_interval < end_index - frame_interval:
|
||||
pos = int(math.ceil(frame_interval * cnt))
|
||||
positions.append((pos, images_before_this_segment + cnt))
|
||||
cnt += 1
|
||||
positions.append((end_index, images_before_this_segment + cnt))
|
||||
|
||||
valid_positions = [(pos, idx) for pos, idx in positions if idx < B and pos < segment_length]
|
||||
|
||||
if valid_positions:
|
||||
seg_positions, img_indices = zip(*valid_positions)
|
||||
seg_positions = torch.tensor(seg_positions, dtype=torch.long, device=images.device)
|
||||
img_indices = torch.tensor(img_indices, dtype=torch.long, device=images.device)
|
||||
mask[seg_positions] = 1
|
||||
keyframes[seg_positions] = images[img_indices]
|
||||
|
||||
# extract audio segment
|
||||
segment_duration = segment_length / fps
|
||||
start_time = segment_index * segment_duration
|
||||
end_time = min(start_time + segment_duration, audio_duration)
|
||||
|
||||
sample_rate = audio["sample_rate"]
|
||||
start_sample = int(start_time * sample_rate)
|
||||
end_sample = int(end_time * sample_rate)
|
||||
|
||||
audio_segment_waveform = audio["waveform"][:, :, start_sample:end_sample]
|
||||
audio_segment = {
|
||||
"waveform": audio_segment_waveform,
|
||||
"sample_rate": sample_rate
|
||||
}
|
||||
|
||||
return keyframes, mask, audio_segment
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, segment_length, segment_index, audio=None) -> io.NodeOutput:
|
||||
return io.NodeOutput(*cls.do_execute(images, segment_length, segment_index, audio))
|
||||
|
||||
class WanDancerPadKeyframesList(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanDancerPadKeyframesList",
|
||||
category="image/video",
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of each segment (usually 149 frames)"),
|
||||
io.Int.Input("num_segments", default=1, min=1, max=100, tooltip="How many padded segments to emit as lists."),
|
||||
io.Audio.Input("audio", tooltip="Audio to slice for each emitted segment."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequences", is_output_list=True),
|
||||
io.Mask.Output(display_name="keyframes_mask", tooltip="Masks indicating valid frames", is_output_list=True),
|
||||
io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for each video segment", is_output_list=True),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, segment_length, num_segments, audio=None) -> io.NodeOutput:
|
||||
outputs = [WanDancerPadKeyframes.do_execute(images, segment_length, i, audio) for i in range(num_segments)]
|
||||
keyframes, masks, audio_segments = zip(*outputs)
|
||||
return io.NodeOutput(list(keyframes), list(masks), list(audio_segments))
|
||||
|
||||
class WanDancerExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
WanDancerVideo,
|
||||
WanDancerEncodeAudio,
|
||||
WanDancerPadKeyframes,
|
||||
WanDancerPadKeyframesList,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanDancerExtension:
|
||||
return WanDancerExtension()
|
||||
@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
|
||||
|
||||
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||
|
||||
2
nodes.py
2
nodes.py
@ -2429,10 +2429,12 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_number_convert.py",
|
||||
"nodes_painter.py",
|
||||
"nodes_curve.py",
|
||||
"nodes_bg_removal.py",
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py",
|
||||
"nodes_void.py",
|
||||
"nodes_wandancer.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
4712
openapi.yaml
4712
openapi.yaml
File diff suppressed because it is too large
Load Diff
@ -1,4 +1,4 @@
|
||||
comfyui-frontend-package==1.43.17
|
||||
comfyui-frontend-package==1.43.18
|
||||
comfyui-workflow-templates==0.9.72
|
||||
comfyui-embedded-docs==0.4.4
|
||||
torch
|
||||
|
||||
@ -124,9 +124,11 @@ class TestMathExpressionExecute:
|
||||
with pytest.raises(Exception, match="not defined"):
|
||||
self._exec("str(a)", a=42)
|
||||
|
||||
def test_boolean_result_raises(self):
|
||||
with pytest.raises(ValueError, match="got bool"):
|
||||
self._exec("a > b", a=5, b=3)
|
||||
def test_boolean_result(self):
|
||||
result = self._exec("a > b", a=5, b=3)
|
||||
assert result[2] is True
|
||||
result = self._exec("a > b", a=3, b=5)
|
||||
assert result[2] is False
|
||||
|
||||
def test_empty_expression_raises(self):
|
||||
with pytest.raises(ValueError, match="Expression cannot be empty"):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user