Compare commits

..

1 Commits

Author SHA1 Message Date
Ray Suhyun Lee
241d37314d
Merge b947b5a4a3 into 594de378fe 2026-05-08 08:16:09 +02:00
81 changed files with 218 additions and 7140 deletions

View File

@ -431,10 +431,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Adjusts image brightness and contrast using a real-time GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
},
"extra": {}
}
}

View File

@ -162,7 +162,7 @@
},
"revision": 0,
"config": {},
"name": "Canny to Image (Z-Image-Turbo)",
"name": "local-Canny to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1553,8 +1553,7 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Canny to image",
"description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
"category": "Image generation and editing/Canny to image"
}
]
},
@ -1575,4 +1574,4 @@
}
},
"version": 0.4
}
}

View File

@ -192,7 +192,7 @@
},
"revision": 0,
"config": {},
"name": "Canny to Video (LTX 2.0)",
"name": "local-Canny to Video (LTX 2.0)",
"inputNode": {
"id": -10,
"bounding": [
@ -3600,8 +3600,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Canny to video",
"description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
"category": "Video generation and editing/Canny to video"
}
]
},
@ -3617,4 +3616,4 @@
}
},
"version": 0.4
}
}

View File

@ -377,9 +377,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Adds lens-style chromatic aberration (color fringing) using a real-time GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
}
}
}

View File

@ -596,8 +596,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Adjusts saturation, temperature, tint, and vibrance using a real-time GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
}

View File

@ -1129,8 +1129,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Balances colors across shadows, midtones, and highlights using a real-time GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
}

View File

@ -608,8 +608,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Fine-tunes tone and color with per-channel curve adjustments using a real-time GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
}

View File

@ -1609,8 +1609,7 @@
}
],
"extra": {},
"category": "Image Tools/Crop",
"description": "Splits an image into a 2×2 grid of four equal tiles."
"category": "Image Tools/Crop"
}
]
},

View File

@ -2946,8 +2946,7 @@
}
],
"extra": {},
"category": "Image Tools/Crop",
"description": "Splits an image into a 3×3 grid of nine equal tiles."
"category": "Image Tools/Crop"
}
]
},

View File

@ -1579,8 +1579,7 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Depth to image",
"description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
"category": "Image generation and editing/Depth to image"
},
{
"id": "458bdf3c-4b58-421c-af50-c9c663a4d74c",
@ -2462,8 +2461,7 @@
]
},
"workflowRendererVersion": "LG"
},
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
}
]
},

View File

@ -4233,8 +4233,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Depth to video",
"description": "Generates video from depth maps using LTX-2, with optional synchronized audio."
"category": "Video generation and editing/Depth to video"
},
{
"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2",
@ -5193,8 +5192,7 @@
],
"extra": {
"workflowRendererVersion": "LG"
},
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
}
]
},

View File

@ -450,10 +450,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Blur",
"description": "Applies bilateral (edge-preserving) blur to soften images while retaining detail."
"category": "Image Tools/Blur"
}
]
},
"extra": {}
}
}

View File

@ -580,9 +580,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Adds procedural film grain texture for a cinematic look via GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
}
}
}

View File

@ -3350,8 +3350,7 @@
}
],
"extra": {},
"category": "Video generation and editing/First-Last-Frame to Video",
"description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
"category": "Video generation and editing/First-Last-Frame to Video"
}
]
},

View File

@ -575,9 +575,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Adds a glow/bloom effect around bright image areas via GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
}
}
}

View File

@ -752,9 +752,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Adjusts hue, saturation, and lightness of an image using a real-time GPU fragment shader."
"category": "Image Tools/Color adjust"
}
]
}
}
}

View File

@ -374,8 +374,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Blur",
"description": "Applies Gaussian, Box, or Radial blur to soften images and create stylized depth or motion effects."
"category": "Image Tools/Blur"
}
]
}

View File

@ -310,8 +310,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Image Captioning",
"description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
"category": "Text generation/Image Captioning"
}
]
}

View File

@ -315,9 +315,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Manipulates individual RGBA channels for masking, compositing, and channel effects."
"category": "Image Tools/Color adjust"
}
]
}
}
}

View File

@ -2138,8 +2138,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Edit image",
"description": "Edits images via text instructions using FireRed Image Edit 1.1, a diffusion-based instruction-following editing model."
"category": "Image generation and editing/Edit image"
}
]
},

View File

@ -1472,8 +1472,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Edit image",
"description": "Edits an input image via text instructions using FLUX.2 [klein] 4B."
"category": "Image generation and editing/Edit image"
},
{
"id": "6007e698-2ebd-4917-84d8-299b35d7b7ab",
@ -1822,8 +1821,7 @@
],
"extra": {
"workflowRendererVersion": "LG"
},
"description": "Applies reference image conditioning for style/identity transfer (Flux.2 Klein 4B)."
}
}
]
},
@ -1839,4 +1837,4 @@
}
},
"version": 0.4
}
}

View File

@ -1417,8 +1417,7 @@
}
],
"extra": {},
"category": "Image generation and editing/Edit image",
"description": "Edits images via text instructions using LongCat Image Edit, an instruction-following image editing diffusion model."
"category": "Image generation and editing/Edit image"
}
]
},

View File

@ -132,7 +132,7 @@
},
"revision": 0,
"config": {},
"name": "Image Edit (Qwen 2511)",
"name": "local-Image Edit (Qwen 2511)",
"inputNode": {
"id": -10,
"bounding": [
@ -1468,8 +1468,7 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Edit image",
"description": "Edits images via text instructions using Qwen-Image-Edit-2511 with improved character consistency and integrated LoRA."
"category": "Image generation and editing/Edit image"
}
]
},
@ -1490,4 +1489,4 @@
}
},
"version": 0.4
}
}

View File

@ -1188,8 +1188,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Inpaint image",
"description": "Inpaints masked image regions using Flux.1 fill [dev], Black Forest Labs' inpainting/outpainting model."
"category": "Image generation and editing/Inpaint image"
}
]
},
@ -1203,4 +1202,4 @@
},
"ue_links": []
}
}
}

View File

@ -1548,8 +1548,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Inpaint image",
"description": "Inpaints masked regions using Qwen-Image, extending its multilingual text rendering to inpainting tasks."
"category": "Image generation and editing/Inpaint image"
},
{
"id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8",
@ -1908,8 +1907,7 @@
],
"extra": {
"workflowRendererVersion": "LG"
},
"description": "Expands and softens mask edges to reduce visible seams after image processing."
}
}
]
},

View File

@ -742,10 +742,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Color adjust",
"description": "Adjusts black point, white point, and gamma for tonal range control via GPU shader."
"category": "Image Tools/Color adjust"
}
]
},
"extra": {}
}
}

View File

@ -1919,8 +1919,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Outpaint image",
"description": "Outpaints beyond image boundaries using Qwen-Image's outpainting capabilities."
"category": "Image generation and editing/Outpaint image"
},
{
"id": "f93c215e-c393-460e-9534-ed2c3d8a652e",
@ -2279,8 +2278,7 @@
],
"extra": {
"workflowRendererVersion": "LG"
},
"description": "Expands and softens mask edges to reduce visible seams after image processing."
}
},
{
"id": "2a4b2cc0-db37-4302-a067-da392f38f06b",
@ -2735,8 +2733,7 @@
],
"extra": {
"workflowRendererVersion": "LG"
},
"description": "Scales both image and mask together while preserving alignment for editing workflows."
}
}
]
},

View File

@ -141,7 +141,7 @@
},
"revision": 0,
"config": {},
"name": "Image Upscale (Z-image-Turbo)",
"name": "local-Image Upscale(Z-image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1302,8 +1302,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Enhance",
"description": "Upscales images to higher resolution using Z-Image-Turbo."
"category": "Image generation and editing/Enhance"
}
]
},

View File

@ -99,7 +99,7 @@
},
"revision": 0,
"config": {},
"name": "Image to Depth Map (Lotus)",
"name": "local-Image to Depth Map (Lotus)",
"inputNode": {
"id": -10,
"bounding": [
@ -948,8 +948,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Depth to image",
"description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
"category": "Image generation and editing/Depth to image"
}
]
},
@ -965,4 +964,4 @@
"workflowRendererVersion": "LG"
},
"version": 0.4
}
}

View File

@ -1586,8 +1586,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Image to layers",
"description": "Decomposes an image into variable-resolution RGBA layers for independent editing using Qwen-Image-Layered."
"category": "Image generation and editing/Image to layers"
}
]
},

View File

@ -72,7 +72,7 @@
},
"revision": 0,
"config": {},
"name": "Image to 3D Model (Hunyuan3d 2.1)",
"name": "local-Image to Model (Hunyuan3d 2.1)",
"inputNode": {
"id": -10,
"bounding": [
@ -765,8 +765,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "3D/Image to 3D Model",
"description": "Generates 3D mesh models from a single input image using Hunyuan3D 2.0/2.1."
"category": "3D/Image to 3D Model"
}
]
},

View File

@ -4223,8 +4223,7 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
"category": "Video generation and editing/Image to video",
"description": "Generates video from a single input image using LTX-2.3."
"category": "Video generation and editing/Image to video"
}
]
},

View File

@ -206,7 +206,7 @@
},
"revision": 0,
"config": {},
"name": "Image to Video (Wan 2.2)",
"name": "local-Image to Video (Wan 2.2)",
"inputNode": {
"id": -10,
"bounding": [
@ -2027,8 +2027,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Image to video",
"description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V."
"category": "Video generation and editing/Image to video"
}
]
},

View File

@ -134,7 +134,7 @@
},
"revision": 0,
"config": {},
"name": "Pose to Image (Z-Image-Turbo)",
"name": "local-Pose to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1298,8 +1298,7 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
"category": "Image generation and editing/Pose to image",
"description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
"category": "Image generation and editing/Pose to image"
}
]
},
@ -1320,4 +1319,4 @@
}
},
"version": 0.4
}
}

View File

@ -3870,8 +3870,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Pose to video",
"description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
"category": "Video generation and editing/Pose to video"
}
]
},

View File

@ -270,10 +270,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Prompt enhance",
"description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
"category": "Text generation/Prompt enhance"
}
]
},
"extra": {}
}
}

View File

@ -302,9 +302,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Sharpen",
"description": "Sharpens image details using a GPU fragment shader for enhanced clarity."
"category": "Image Tools/Sharpen"
}
]
}
}
}

View File

@ -222,7 +222,7 @@
},
"revision": 0,
"config": {},
"name": "Text to Audio (ACE-Step 1.5)",
"name": "local-Text to Audio (ACE-Step 1.5)",
"inputNode": {
"id": -10,
"bounding": [
@ -1502,8 +1502,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Audio/Music generation",
"description": "Generates audio/music from text prompts using ACE-Step 1.5, a diffusion-based audio generation model."
"category": "Audio/Music generation"
}
]
},
@ -1519,4 +1518,4 @@
}
},
"version": 0.4
}
}

View File

@ -1029,8 +1029,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model."
"category": "Image generation and editing/Text to image"
}
]
},
@ -1044,4 +1043,4 @@
},
"ue_links": []
}
}
}

View File

@ -1023,8 +1023,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant."
"category": "Image generation and editing/Text to image"
}
]
},
@ -1038,4 +1037,4 @@
},
"ue_links": []
}
}
}

View File

@ -1104,8 +1104,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using NetaYume Lumina, fine-tuned from Neta Lumina for anime-style and illustration generation."
"category": "Image generation and editing/Text to image"
},
{
"id": "a07fdf06-1bda-4dac-bdbd-63ee8ebca1c9",
@ -1459,12 +1458,11 @@
],
"extra": {
"workflowRendererVersion": "LG"
},
"description": "Encodes a negative text prompt via CLIP for classifier-free guidance in anime-style generation (NetaYume Lumina)."
}
}
]
},
"extra": {
"ue_links": []
}
}
}

View File

@ -1941,8 +1941,7 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Qwen-Image-2512, with enhanced human realism and finer natural detail over the base version."
"category": "Image generation and editing/Text to image"
}
]
},

View File

@ -1873,8 +1873,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Qwen-Image, Alibaba's 20B MMDiT model with excellent multilingual text rendering."
"category": "Image generation and editing/Text to image"
}
]
},

View File

@ -149,7 +149,7 @@
},
"revision": 0,
"config": {},
"name": "Text to Image (Z-Image-Turbo)",
"name": "local-Text to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@ -1054,8 +1054,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image generation and editing/Text to image",
"description": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model."
"category": "Image generation and editing/Text to image"
}
]
},
@ -1076,4 +1075,4 @@
}
},
"version": 0.4
}
}

View File

@ -4286,8 +4286,7 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
"category": "Video generation and editing/Text to video",
"description": "Generates video from text prompts using LTX-2.3, Lightricks' video diffusion model."
"category": "Video generation and editing/Text to video"
}
]
},

View File

@ -1572,8 +1572,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Text to video",
"description": "Generates video from text prompts using Wan2.2, Alibaba's diffusion video model."
"category": "Video generation and editing/Text to video"
}
]
},
@ -1587,4 +1586,4 @@
"VHS_KeepIntermediate": true
},
"version": 0.4
}
}

View File

@ -434,9 +434,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Image Tools/Sharpen",
"description": "Enhances edge contrast via unsharp masking for a sharper image appearance."
"category": "Image Tools/Sharpen"
}
]
}
}
}

View File

@ -307,8 +307,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Text generation/Video Captioning",
"description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
"category": "Text generation/Video Captioning"
}
]
}

View File

@ -165,7 +165,7 @@
},
"revision": 0,
"config": {},
"name": "Video Inpaint (Wan 2.1 VACE)",
"name": "local-Video Inpaint(Wan2.1 VACE)",
"inputNode": {
"id": -10,
"bounding": [
@ -2368,8 +2368,7 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Inpaint video",
"description": "Inpaints masked regions in video frames using Wan 2.1 VACE."
"category": "Video generation and editing/Inpaint video"
}
]
},

View File

@ -584,9 +584,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video Tools/Stitch videos",
"description": "Stitches multiple video clips into a single sequential video file."
"category": "Video Tools/Stitch videos"
}
]
}
}
}

View File

@ -412,10 +412,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
"category": "Video generation and editing/Enhance video",
"description": "Upscales video to 4× resolution using a GAN-based upscaling model."
"category": "Video generation and editing/Enhance video"
}
]
},
"extra": {}
}
}

View File

@ -1,7 +0,0 @@
{
"model_type": "birefnet",
"image_std": [1.0, 1.0, 1.0],
"image_mean": [0.0, 0.0, 0.0],
"image_size": 1024,
"resize_to_original": true
}

View File

@ -1,689 +0,0 @@
import torch
import comfy.ops
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from torchvision.ops import deform_conv2d
from comfy.ldm.modules.attention import optimized_attention_for_device
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
x = optimized_attention(
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None,
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
B, L, C = x.shape
H, W = self.H, self.W
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchMerging(nn.Module):
def __init__(self, dim, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
def forward(self, x, H, W):
B, L, C = x.shape
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
norm_layer=nn.LayerNorm,
downsample=None,
device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
device=device, dtype=dtype, operations=operations)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
else:
self.downsample = None
def forward(self, x, H, W):
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
else:
self.norm = None
def forward(self, x):
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
device=None, dtype=None, operations=None):
super().__init__()
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
device=device, dtype=dtype, operations=operations,
norm_layer=norm_layer if self.patch_norm else None)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
device=device, dtype=dtype, operations=operations)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
def forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
outs = []
x = x.flatten(2).transpose(1, 2)
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False, device=None, dtype=None, operations=None):
super(DeformableConv2d, self).__init__()
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) is tuple else (stride, stride)
self.padding = padding
self.offset_conv = operations.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.modulator_conv = operations.Conv2d(in_channels,
1 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias, device=device, dtype=dtype)
def forward(self, x):
offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d(
input=x,
offset=offset,
weight=weight,
bias=None,
padding=self.padding,
mask=modulator,
stride=self.stride,
)
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x
class BasicDecBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
super(BasicDecBlk, self).__init__()
inter_channels = 64
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.relu_in = nn.ReLU(inplace=True)
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
def forward(self, x):
x = self.conv_in(x)
x = self.bn_in(x)
x = self.relu_in(x)
x = self.dec_att(x)
x = self.conv_out(x)
x = self.bn_out(x)
return x
class BasicLatBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
super(BasicLatBlk, self).__init__()
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return x
class _ASPPModuleDeformable(nn.Module):
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
super(_ASPPModuleDeformable, self).__init__()
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
class ASPPDeformable(nn.Module):
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
super(ASPPDeformable, self).__init__()
self.down_scale = 1
if out_channels is None:
out_channels = in_channels
self.in_channelster = 256 // self.down_scale
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
self.aspp_deforms = nn.ModuleList([
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
for conv_size in parallel_block_sizes
])
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
nn.ReLU(inplace=True))
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x1 = self.aspp1(x)
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
x5 = self.global_avg_pool(x)
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return x
class BiRefNet(nn.Module):
def __init__(self, config=None, dtype=None, device=None, operations=None):
super(BiRefNet, self).__init__()
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
channels = [1536, 768, 384, 192]
channels = [c * 2 for c in channels]
self.cxt = channels[1:][::-1][-3:]
self.squeeze_module = nn.Sequential(*[
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
for _ in range(1)
])
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
def forward_enc(self, x):
x1, x2, x3, x4 = self.bb(x)
B, C, H, W = x.shape
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat(
(
*[
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
][-len(CXT):],
x4
),
dim=1
)
return (x1, x2, x3, x4)
def forward_ori(self, x):
(x1, x2, x3, x4) = self.forward_enc(x)
x4 = self.squeeze_module(x4)
features = [x, x1, x2, x3, x4]
scaled_preds = self.decoder(features)
return scaled_preds
def forward(self, pixel_values, intermediate_output=None):
scaled_preds = self.forward_ori(pixel_values)
return scaled_preds
class Decoder(nn.Module):
def __init__(self, channels, device, dtype, operations):
super(Decoder, self).__init__()
# factory kwargs
fk = {"device":device, "dtype":dtype, "operations":operations}
DecoderBlock = partial(BasicDecBlk, **fk)
LateralBlock = partial(BasicLatBlk, **fk)
DBlock = partial(SimpleConvs, **fk)
self.split = True
N_dec_ipt = 64
ic = 64
ipt_cha_opt = 1
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
fk = {"device":device, "dtype":dtype}
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
self.lateral_block4 = LateralBlock(channels[1], channels[1])
self.lateral_block3 = LateralBlock(channels[2], channels[2])
self.lateral_block2 = LateralBlock(channels[3], channels[3])
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
_N = 16
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
def get_patches_batch(self, x, p):
_size_h, _size_w = p.shape[2:]
patches_batch = []
for idx in range(x.shape[0]):
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
patches_x = []
for column_x in columns_x:
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
patch_sample = torch.cat(patches_x, dim=1)
patches_batch.append(patch_sample)
return torch.cat(patches_batch, dim=0)
def forward(self, features):
x, x1, x2, x3, x4 = features
patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
p4_gdt = self.gdt_convs_4(p4)
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
p4 = p4 * gdt_attn_4
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
_p3 = _p4 + self.lateral_block4(x3)
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
p3_gdt = self.gdt_convs_3(p3)
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
p3 = p3 * gdt_attn_3
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
_p2 = _p3 + self.lateral_block3(x2)
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
p2_gdt = self.gdt_convs_2(p2)
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
p2 = p2 * gdt_attn_2
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
_p1 = _p2 + self.lateral_block2(x1)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
_p1 = self.decoder_block1(_p1)
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)
return p1_out
class SimpleConvs(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
def forward(self, x):
return self.conv_out(self.conv1(x))

View File

@ -1,78 +0,0 @@
from .utils import load_torch_file
import os
import json
import torch
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.clip_model
import comfy.background_removal.birefnet
BG_REMOVAL_MODELS = {
"birefnet": comfy.background_removal.birefnet.BiRefNet
}
class BackgroundRemovalModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 1024)
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
self.model_type = config.get("model_type", "birefnet")
self.config = config.copy()
model_class = BG_REMOVAL_MODELS.get(self.model_type)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
def load_background_removal_model(sd):
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
else:
return None
bg_model = BackgroundRemovalModel(json_config)
m, u = bg_model.load_sd(sd)
if len(m) > 0:
logging.warning("missing background removal: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
return bg_model
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
return load_background_removal_model(sd)

View File

@ -93,7 +93,7 @@ class Hook:
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):

View File

@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according to the formula provided in https://arxiv.org/abs/2010.02502
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')

View File

@ -1135,7 +1135,7 @@ class AudioInjector_WAN(nn.Module):
self.injector_adain_output_layers = nn.ModuleList(
[operations.Linear(dim, dim, dtype=dtype, device=device) for _ in range(audio_injector_id)])
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len, scale=1.0):
def forward(self, x, block_id, audio_emb, audio_emb_global, seq_len):
audio_attn_id = self.injected_block_id.get(block_id, None)
if audio_attn_id is None:
return x
@ -1148,15 +1148,12 @@ class AudioInjector_WAN(nn.Module):
attn_hidden_states = adain_hidden_states
else:
attn_hidden_states = self.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
if audio_emb.dim() == 3: # WanDancer case
attn_audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
else: # S2V case
attn_audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.injector[audio_attn_id](x=attn_hidden_states, context=attn_audio_emb)
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
x[:, :seq_len] = x[:, :seq_len] + residual_out * scale
residual_out = rearrange(
residual_out, "(b t) n c -> b (t n) c", t=num_frames)
x[:, :seq_len] = x[:, :seq_len] + residual_out
return x

View File

@ -1,251 +0,0 @@
import torch
import torch.nn as nn
import comfy
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
from .model import AudioInjector_WAN, WanModel, MLPProj, Head, sinusoidal_embedding_1d
class MusicSelfAttention(nn.Module):
def __init__(self, dim, num_heads, device=None, dtype=None, operations=None):
assert dim % num_heads == 0
super().__init__()
self.embed_dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.q_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x, freqs):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = self.q_proj(x).view(b, s, n, d)
q = apply_rope1(q, freqs)
k = self.k_proj(x).view(b, s, n, d)
k = apply_rope1(k, freqs)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
self.v_proj(x).view(b, s, n * d),
heads=self.num_heads,
)
return self.out_proj(x)
class MusicEncoderLayer(nn.Module):
def __init__(self, dim: int, num_heads: int, ffn_dim: int, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = MusicSelfAttention(dim, num_heads, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(dim, ffn_dim, device=device, dtype=dtype)
self.linear2 = operations.Linear(ffn_dim, dim, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
def forward(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
x = x + self.self_attn(self.norm1(x), freqs=freqs)
x = x + self.linear2(torch.nn.functional.gelu(self.linear1(self.norm2(x)))) # ffn
return x
class WanDancerModel(WanModel):
def __init__(self,
model_type='wandancer',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=5120,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=40,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
in_dim_ref_conv=None,
image_model=None,
device=None, dtype=None, operations=None,
audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
music_dim = 256,
music_heads = 4,
music_feature_dim = 35,
music_latent_dim = 256
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim,
num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model, in_dim_ref_conv=in_dim_ref_conv,
device=device, dtype=dtype, operations=operations)
self.dtype = dtype
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.patch_embedding_global = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
self.img_emb_refimage = MLPProj(1280, dim, operation_settings=operation_settings)
self.head_global = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
self.music_injector = AudioInjector_WAN(
dim=self.dim,
num_heads=self.num_heads,
inject_layer=audio_inject_layers,
root_net=self,
enable_adain=False,
dtype=dtype, device=device, operations=operations
)
self.music_projection = operations.Linear(music_feature_dim, music_latent_dim, device=device, dtype=dtype)
self.music_encoder = nn.ModuleList([MusicEncoderLayer(dim=music_dim, num_heads=music_heads, ffn_dim=1024, device=device, dtype=dtype, operations=operations) for _ in range(2)])
music_head_dim = music_dim // music_heads
self.music_rope_embedder = EmbedND(dim=music_head_dim, theta=10000.0, axes_dim=[music_head_dim])
def forward_orig(self, x, t, context, clip_fea=None, clip_fea_ref=None, freqs=None, audio_embed=None, fps=30, audio_inject_scale=1.0, transformer_options={}, **kwargs):
# embeddings
if int(fps + 0.5) != 30:
x = self.patch_embedding_global(x.float()).to(x.dtype)
else:
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
latent_frames = grid_sizes[0]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
seq_len = x.size(1)
# time embeddings
e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
full_ref = None
if self.ref_conv is not None: # model has the weight, but this wasn't used in the original pipeline
full_ref = kwargs.get("reference_latent", None)
if full_ref is not None:
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)
# context
context = self.text_embedding(context)
audio_emb = None
if audio_embed is not None: # encode music feature[1, frame_num, 35] -> [1, F*8, dim]
music_feature = self.music_projection(audio_embed)
music_seq_len = music_feature.shape[1]
music_ids = torch.arange(music_seq_len, device=music_feature.device, dtype=music_feature.dtype).reshape(1, -1, 1) # create 1D position IDs
music_freqs = self.music_rope_embedder(music_ids).movedim(1, 2)
# apply encoder layers
for layer in self.music_encoder:
music_feature = layer(music_feature, music_freqs)
# interpolate
audio_emb = torch.nn.functional.interpolate(music_feature.unsqueeze(1), size=(latent_frames * 8, self.dim), mode='bilinear').squeeze(1)
context_img_len = 0
if self.img_emb is not None and clip_fea is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.cat([context_clip, context], dim=1)
context_img_len += clip_fea.shape[-2]
if self.img_emb_refimage is not None and clip_fea_ref is not None:
context_clip_ref = self.img_emb_refimage(clip_fea_ref)
context = torch.cat([context_clip_ref, context], dim=1)
context_img_len += clip_fea_ref.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options)
if audio_emb is not None:
x = self.music_injector(x, i, audio_emb, audio_emb_global=None, seq_len=seq_len, scale=audio_inject_scale)
# head
if int(fps + 0.5) != 30:
x = self.head_global(x, e)
else:
x = self.head(x, e)
if full_ref is not None:
x = x[:, full_ref.shape[1]:]
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, clip_fea_ref=None, fps=30, audio_inject_scale=1.0, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
t_len = t
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = x.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, fps=fps, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, clip_fea_ref=clip_fea_ref, freqs=freqs, fps=fps, audio_inject_scale=audio_inject_scale, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, fps=30, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if steps_t is None:
steps_t = t_len
if steps_h is None:
steps_h = h_len
if steps_w is None:
steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
if int(fps + 0.5) != 30:
time_scale = 30.0 / fps # how many time units each frame represents relative to 30fps
positions_new = torch.arange(steps_t, device=device, dtype=dtype) * time_scale + t_start
total_frames_at_30fps = int(time_scale * steps_t + 0.5)
positions_new[-1] = t_start + (total_frames_at_30fps - 1)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + positions_new.reshape(-1, 1, 1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return freqs

View File

@ -43,7 +43,6 @@ import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.wan.model_wandancer
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
@ -1600,30 +1599,6 @@ class WAN21_SCAIL(WAN21):
return out
class WAN22_WanDancer(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
clip_vision_output_ref = kwargs.get("clip_vision_output_ref", None)
if clip_vision_output_ref is not None:
out['clip_fea_ref'] = comfy.conds.CONDRegular(clip_vision_output_ref.penultimate_hidden_states)
fps = kwargs.get("fps", None)
if fps is not None:
out['fps'] = comfy.conds.CONDRegular(torch.FloatTensor([fps]))
audio_inject_scale = kwargs.get("audio_inject_scale", None)
if audio_inject_scale is not None:
out['audio_inject_scale'] = comfy.conds.CONDRegular(torch.FloatTensor([audio_inject_scale]))
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)

View File

@ -572,8 +572,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "animate"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "wandancer"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"

View File

@ -562,25 +562,6 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -768,9 +749,6 @@ class manual_cast(disable_weight_init):
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True

View File

@ -1313,37 +1313,6 @@ class WAN21_SCAIL(WAN21_T2V):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out
class WAN22_WanDancer(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "wandancer",
"in_dim": 36,
}
def __init__(self, unet_config):
super().__init__(unet_config)
self.memory_usage_factor = 1.8
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN22_WanDancer(self, image_to_video=True, device=device)
return out
def process_unet_state_dict(self, state_dict):
out_sd = {}
for k in list(state_dict.keys()):
# split music_encoder in_proj into q_proj, k_proj, v_proj
if "music_encoder" in k and "self_attn.in_proj" in k:
suffix = "weight" if k.endswith("weight") else "bias"
tensor = state_dict[k]
d = tensor.shape[0] // 3
prefix = k.replace(f"in_proj_{suffix}", "")
out_sd[f"{prefix}q_proj.{suffix}"] = tensor[:d]
out_sd[f"{prefix}k_proj.{suffix}"] = tensor[d:2*d]
out_sd[f"{prefix}v_proj.{suffix}"] = tensor[2*d:]
else:
out_sd[k] = state_dict[k]
return out_sd
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",
@ -2013,7 +1982,6 @@ models = [
WAN22_Animate,
WAN21_FlowRVS,
WAN21_SCAIL,
WAN22_WanDancer,
Hunyuan3Dv2mini,
Hunyuan3Dv2,
Hunyuan3Dv2_1,

View File

@ -1390,7 +1390,7 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
k_out = "{}.weight_scale".format(layer)
if layer is not None:
layer_conf = {"format": "float8_e4m3fn"}
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
if full_precision_matrix_mult:
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
layers[layer] = layer_conf

View File

@ -17,7 +17,6 @@ if TYPE_CHECKING:
from spandrel import ImageModelDescriptor
from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_
from comfy.bg_removal_model import BackgroundRemovalModel
from comfy.controlnet import ControlNet
from comfy.hooks import HookGroup, HookKeyframeGroup
from comfy.model_patcher import ModelPatcher
@ -615,11 +614,6 @@ class Model(ComfyTypeIO):
if TYPE_CHECKING:
Type = ModelPatcher
@comfytype(io_type="BACKGROUND_REMOVAL")
class BackgroundRemoval(ComfyTypeIO):
if TYPE_CHECKING:
Type = BackgroundRemovalModel
@comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO):
if TYPE_CHECKING:
@ -2263,7 +2257,6 @@ __all__ = [
"ModelPatch",
"ClipVision",
"ClipVisionOutput",
"BackgroundRemoval",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",

View File

@ -1,11 +1,10 @@
from __future__ import annotations
from enum import Enum
from typing import Optional, Any
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_1_20260211 = 'v3.1-20260211'
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
@ -143,7 +142,7 @@ class TripoFileEmptyReference(BaseModel):
pass
class TripoFileReference(RootModel):
root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
class TripoGetStsTokenRequest(BaseModel):
format: str = Field(..., description='The format of the image')
@ -184,7 +183,7 @@ class TripoImageToModelRequest(BaseModel):
class TripoMultiviewToModelRequest(BaseModel):
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
files: list[TripoFileReference] = Field(..., description='The file references to convert to a model')
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
@ -252,13 +251,27 @@ class TripoConvertModelRequest(BaseModel):
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[list[str]] = Field(None, description='The names of the parts to include')
part_names: Optional[List[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
class TripoTaskRequest(RootModel):
root: Union[
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimatePrerigcheckRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoStylizeModelRequest,
TripoConvertModelRequest
]
class TripoTaskOutput(BaseModel):
model: Optional[str] = Field(None, description='URL to the model')
base_model: Optional[str] = Field(None, description='URL to the base model')
@ -270,13 +283,12 @@ class TripoTask(BaseModel):
task_id: str = Field(..., description='The task ID')
type: Optional[str] = Field(None, description='The type of task')
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task')
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
create_time: Optional[int] = Field(None, description='The creation time of the task')
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
queue_position: Optional[int] = Field(None, description='The position in the queue')
consumed_credit: int | None = Field(None)
class TripoTaskResponse(BaseModel):
code: int = Field(0, description='The response code')
@ -284,7 +296,7 @@ class TripoTaskResponse(BaseModel):
class TripoGeneralResponse(BaseModel):
code: int = Field(0, description='The response code')
data: dict[str, str] = Field(..., description='The task ID data')
data: Dict[str, str] = Field(..., description='The task ID data')
class TripoBalanceData(BaseModel):
balance: float = Field(..., description='The account balance')

View File

@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
)
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
def _seedance2_text_inputs(resolutions: list[str]):
return [
IO.String.Input(
"prompt",
@ -1287,7 +1287,6 @@ def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
IO.Combo.Input(
"ratio",
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
default=default_ratio,
tooltip="Aspect ratio of the output video.",
),
IO.Int.Input(
@ -1421,14 +1420,8 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
@ -1595,9 +1588,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
def _seedance2_reference_inputs(resolutions: list[str]):
return [
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
*_seedance2_text_inputs(resolutions),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
@ -1675,14 +1668,8 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),

View File

@ -60,7 +60,6 @@ async def poll_until_finished(
],
status_extractor=lambda x: x.data.status,
progress_extractor=lambda x: x.data.progress,
price_extractor=lambda x: x.data.consumed_credit * 0.01 if x.data.consumed_credit else None,
estimated_duration=average_duration,
)
if response_poll.data.status == TripoTaskStatus.SUCCESS:
@ -114,6 +113,7 @@ class TripoTextToModelNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
@ -124,17 +124,20 @@ class TripoTextToModelNode(IO.ComfyNode):
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$isV3OrLater := $contains(widgets.model_version,"v3.");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$credits := $isV14 ? 20 : (
($withTexture ? 20 : 10)
$baseCredits :=
$isV14 ? 20 : ($withTexture ? 20 : 10);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
);
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
@ -236,6 +239,7 @@ class TripoImageToModelNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(
widgets=[
"model_version",
"style",
"texture",
"pbr",
"quad",
@ -246,17 +250,20 @@ class TripoImageToModelNode(IO.ComfyNode):
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$isV3OrLater := $contains(widgets.model_version,"v3.");
$style := widgets.style;
$hasStyle := ($style != "" and $style != "none");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$credits := $isV14 ? 30 : (
($withTexture ? 30 : 20)
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ ($hasStyle ? 5 : 0)
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
);
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
@ -351,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
],
outputs=[
@ -372,6 +379,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
"model_version",
"texture",
"pbr",
"quad",
"texture_quality",
"geometry_quality",
],
@ -379,16 +387,17 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
expr="""
(
$isV14 := $contains(widgets.model_version,"v1.4");
$isV3OrLater := $contains(widgets.model_version,"v3.");
$withTexture := widgets.texture or widgets.pbr;
$isHdTexture := (widgets.texture_quality = "detailed");
$isDetailedGeometry := (widgets.geometry_quality = "detailed");
$credits := $isV14 ? 30 : (
($withTexture ? 30 : 20)
$baseCredits :=
$isV14 ? 30 : ($withTexture ? 30 : 20);
$credits :=
$baseCredits
+ (widgets.quad ? 5 : 0)
+ ($isHdTexture ? 10 : 0)
+ (($isDetailedGeometry and $isV3OrLater) ? 20 : 0)
);
{"type":"usd","usd": $round($credits * 0.01, 2), "format": {"approximate": true}}
+ ($isDetailedGeometry ? 20 : 0);
{"type":"usd","usd": $round($credits * 0.01, 2)}
)
""",
),
@ -448,7 +457,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
geometry_quality=geometry_quality,
texture_alignment=texture_alignment,
face_limit=face_limit if face_limit != -1 else None,
quad=None,
quad=quad,
),
)
return await poll_until_finished(cls, response, average_duration=80)
@ -489,7 +498,7 @@ class TripoTextureNode(IO.ComfyNode):
expr="""
(
$tq := widgets.texture_quality;
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1), "format": {"approximate": true}}
{"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)}
)
""",
),
@ -546,7 +555,7 @@ class TripoRefineNode(IO.ComfyNode):
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.3, "format": {"approximate": true}}""",
expr="""{"type":"usd","usd":0.3}""",
),
)
@ -583,7 +592,7 @@ class TripoRigNode(IO.ComfyNode):
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.25, "format": {"approximate": true}}""",
expr="""{"type":"usd","usd":0.25}""",
),
)
@ -643,7 +652,7 @@ class TripoRetargetNode(IO.ComfyNode):
is_api_node=True,
is_output_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.1, "format": {"approximate": true}}""",
expr="""{"type":"usd","usd":0.1}""",
),
)
@ -752,10 +761,19 @@ class TripoConversionNode(IO.ComfyNode):
"face_limit",
"texture_size",
"texture_format",
"force_symmetry",
"flatten_bottom",
"flatten_bottom_threshold",
"pivot_to_center_bottom",
"scale_factor",
"with_animation",
"pack_uv",
"bake",
"part_names",
"fbx_preset",
"export_vertex_colors",
"export_orientation",
"animate_in_place",
],
),
expr="""
@ -765,16 +783,28 @@ class TripoConversionNode(IO.ComfyNode):
$flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0;
$scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1;
$texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg");
$part := widgets.part_names;
$fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender");
$orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default");
$advanced :=
widgets.quad or
widgets.force_symmetry or
widgets.flatten_bottom or
widgets.pivot_to_center_bottom or
widgets.with_animation or
widgets.pack_uv or
widgets.bake or
widgets.export_vertex_colors or
widgets.animate_in_place or
($face != -1) or
($texSize != 4096) or
($flatThresh != 0) or
($scale != 1) or
($texFmt != "jpeg");
{"type":"usd","usd": ($advanced ? 0.1 : 0.05), "format": {"approximate": true}}
($texFmt != "jpeg") or
($part != "") or
($fbx != "blender") or
($orient != "default");
{"type":"usd","usd": ($advanced ? 0.1 : 0.05)}
)
""",
),

View File

@ -488,30 +488,10 @@ async def _diagnose_connectivity() -> dict[str, bool]:
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
# Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required
# to correctly detect that Chinese users with working internet do have working internet.
internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
async with aiohttp.ClientSession(timeout=timeout) as session:
async def _probe(url: str) -> bool:
try:
async with session.get(url) as resp:
return resp.status < 500
except (ClientError, OSError, asyncio.TimeoutError):
return False
probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
try:
for fut in asyncio.as_completed(probe_tasks):
if await fut:
results["internet_accessible"] = True
break
finally:
for t in probe_tasks:
if not t.done():
t.cancel()
await asyncio.gather(*probe_tasks, return_exceptions=True)
with contextlib.suppress(ClientError, OSError):
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
if not results["internet_accessible"]:
return results

View File

@ -1,60 +0,0 @@
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy.bg_removal_model import load
class LoadBackgroundRemovalModel(IO.ComfyNode):
@classmethod
def define_schema(cls):
files = folder_paths.get_filename_list("background_removal")
return IO.Schema(
node_id="LoadBackgroundRemovalModel",
display_name="Load Background Removal Model",
category="loaders",
inputs=[
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
],
outputs=[
IO.BackgroundRemoval.Output("bg_model")
]
)
@classmethod
def execute(cls, bg_removal_name):
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
bg = load(path)
if bg is None:
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
return IO.NodeOutput(bg)
class RemoveBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RemoveBackground",
display_name="Remove Background",
category="image/background removal",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)
class BackgroundRemovalExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LoadBackgroundRemovalModel,
RemoveBackground
]
async def comfy_entrypoint() -> BackgroundRemovalExtension:
return BackgroundRemovalExtension()

View File

@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))

View File

@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
append = execute # TODO: remove
PREFERRED_KONTEXT_RESOLUTIONS = [
PREFERED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return io.NodeOutput(image)

View File

@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
if bypass:
return (latent,)
samples = latent["samples"].clone()
samples = latent["samples"]
_, height_scale_factor, width_scale_factor = (
vae.downscale_index_formula
)
_, _, _, latent_height, latent_width = samples.shape
batch, _, latent_frames, latent_height, latent_width = samples.shape
width = latent_width * width_scale_factor
height = latent_height * height_scale_factor
@ -124,7 +124,11 @@ class LTXVImgToVideoInplace(io.ComfyNode):
samples[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = get_noise_mask(latent)
conditioning_latent_frames_mask = torch.ones(
(batch, 1, latent_frames, 1, 1),
dtype=torch.float32,
device=samples.device,
)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
@ -232,7 +236,7 @@ class LTXVAddGuide(io.ComfyNode):
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t

View File

@ -40,21 +40,10 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
inverse_mask = torch.ones_like(mask) - mask
source_rgb = source[:, :3, :visible_height, :visible_width]
dest_slice = destination[..., top:bottom, left:right]
if destination.shape[1] == 4:
if torch.max(dest_slice) == 0:
destination[:, :3, top:bottom, left:right] = source_rgb
destination[:, 3:4, top:bottom, left:right] = mask
else:
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
else:
source_portion = mask * source_rgb
destination_portion = inverse_mask * dest_slice
destination[..., top:bottom, left:right] = source_portion + destination_portion
source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked(IO.ComfyNode):
@ -95,23 +84,18 @@ class ImageCompositeMasked(IO.ComfyNode):
display_name="Image Composite Masked",
category="image",
inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False),
IO.Image.Input("destination", optional=True),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
if destination is None: # transparent rgba
B, H, W, C = source.shape
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
if C == 3:
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
@ -397,6 +381,7 @@ class GrowMask(IO.ComfyNode):
expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode):
@classmethod
def define_schema(cls):

View File

@ -63,7 +63,7 @@ class MathExpressionNode(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
autogrow = io.Autogrow.TemplateNames(
input=io.MultiType.Input("value", [io.Float, io.Int, io.Boolean]),
input=io.MultiType.Input("value", [io.Float, io.Int]),
names=list(string.ascii_lowercase),
min=1,
)
@ -82,7 +82,6 @@ class MathExpressionNode(io.ComfyNode):
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
io.Boolean.Output(display_name="BOOL"),
],
)
@ -98,7 +97,7 @@ class MathExpressionNode(io.ComfyNode):
result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
# bool check must come first because bool is a subclass of int in Python
if not isinstance(result, (int, float)):
if isinstance(result, bool) or not isinstance(result, (int, float)):
raise ValueError(
f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}"
@ -107,7 +106,7 @@ class MathExpressionNode(io.ComfyNode):
raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}"
)
return io.NodeOutput(float(result), int(result), bool(result))
return io.NodeOutput(float(result), int(result))
class MathExtension(ComfyExtension):

View File

@ -1,971 +0,0 @@
import math
import nodes
import node_helpers
import torch
import torchaudio
import comfy.model_management
import comfy.utils
import numpy as np
import logging
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import scipy.signal
import scipy.ndimage
import scipy.fft
import scipy.sparse
# Audio Processing Functions - Derived from librosa (https://github.com/librosa/librosa)
# Copyright (c) 2013--2023, librosa development team.
def mel_to_hz(mels, htk=False):
"""Convert mel to Hz (slaney)"""
mels = np.asanyarray(mels)
if htk:
return 700.0 * (10.0 ** (mels / 2595.0) - 1.0)
f_min = 0.0
f_sp = 200.0 / 3
freqs = f_min + f_sp * mels
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = np.log(6.4) / 27.0
if mels.ndim:
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
elif mels >= min_log_mel:
freqs = min_log_hz * np.exp(logstep * (mels - min_log_mel))
return freqs
def hz_to_mel(frequencies, htk=False):
"""Convert Hz to mel (slaney)"""
frequencies = np.asanyarray(frequencies)
if htk:
return 2595.0 * np.log10(1.0 + frequencies / 700.0)
f_min = 0.0
f_sp = 200.0 / 3
mels = (frequencies - f_min) / f_sp
min_log_hz = 1000.0
min_log_mel = (min_log_hz - f_min) / f_sp
logstep = np.log(6.4) / 27.0
if frequencies.ndim:
log_t = frequencies >= min_log_hz
mels[log_t] = min_log_mel + np.log(frequencies[log_t] / min_log_hz) / logstep
elif frequencies >= min_log_hz:
mels = min_log_mel + np.log(frequencies / min_log_hz) / logstep
return mels
def compute_cqt(y, sr=22050, hop_length=512, fmin=None, n_bins=84, bins_per_octave=12, tuning=0.0):
"""Compute Constant-Q Transform (CQT) spectrogram."""
def _relative_bandwidth(freqs):
bpo = np.empty_like(freqs)
logf = np.log2(freqs)
bpo[0] = 1.0 / (logf[1] - logf[0])
bpo[-1] = 1.0 / (logf[-1] - logf[-2])
bpo[1:-1] = 2.0 / (logf[2:] - logf[:-2])
return (2.0 ** (2.0 / bpo) - 1.0) / (2.0 ** (2.0 / bpo) + 1.0)
def _wavelet_lengths(freqs, sr, filter_scale, alpha):
Q = float(filter_scale) / alpha
return Q * sr / freqs # shape (n_bins,) floats
def _build_wavelet(freqs_oct, sr, filter_scale, alpha_oct):
lengths = _wavelet_lengths(freqs_oct, sr, filter_scale, alpha_oct)
filters = []
for ilen, freq in zip(lengths, freqs_oct):
t = np.arange(int(-ilen // 2), int(ilen // 2), dtype=float)
sig = (np.cos(t * 2 * np.pi * freq / sr)
+ 1j * np.sin(t * 2 * np.pi * freq / sr)).astype(np.complex64)
sig *= scipy.signal.get_window('hann', len(sig), fftbins=True)
l1 = np.sum(np.abs(sig))
tiny = np.finfo(np.float32).tiny
sig /= max(l1, tiny)
filters.append(sig)
max_len = max(lengths)
n_fft = int(2.0 ** np.ceil(np.log2(max_len)))
out = np.zeros((len(filters), n_fft), dtype=np.complex64)
for k, f in enumerate(filters):
lpad = int((n_fft - len(f)) // 2)
out[k, lpad: lpad + len(f)] = f
return out, lengths
def _resample_half(y):
ratio = 0.5
n_samples = int(np.ceil(len(y) * ratio))
# Kaiser-windowed FIR matches librosa/soxr more closely than scipy's default Hamming filter
L = 2
h = scipy.signal.firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
y_hat = scipy.signal.resample_poly(y.astype(np.float32), 1, 2, window=h)
if len(y_hat) > n_samples:
y_hat = y_hat[:n_samples]
elif len(y_hat) < n_samples:
y_hat = np.pad(y_hat, (0, n_samples - len(y_hat)))
y_hat /= np.sqrt(ratio)
return y_hat.astype(np.float32)
def _sparsify_rows(x, quantile=0.01):
mags = np.abs(x)
norms = np.sum(mags, axis=1, keepdims=True)
norms = np.where(norms == 0, 1.0, norms)
mag_sort = np.sort(mags, axis=1)
cumulative_mag = np.cumsum(mag_sort / norms, axis=1)
threshold_idx = np.argmin(cumulative_mag < quantile, axis=1)
x_sparse = scipy.sparse.lil_matrix(x.shape, dtype=x.dtype)
for i, j in enumerate(threshold_idx):
idx = np.where(mags[i] >= mag_sort[i, j])
x_sparse[i, idx] = x[i, idx]
return x_sparse.tocsr()
if fmin is None:
fmin = 32.70319566257483 # C1 note frequency
fmin = fmin * (2.0 ** (tuning / bins_per_octave))
freqs = fmin * (2.0 ** (np.arange(n_bins) / bins_per_octave))
alpha = _relative_bandwidth(freqs)
lengths = _wavelet_lengths(freqs, float(sr), 1, alpha)
n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
n_filters = min(bins_per_octave, n_bins)
cqt_resp = []
my_y = y.astype(np.float32)
my_sr = float(sr)
my_hop = int(hop_length)
for i in range(n_octaves):
if i == 0:
sl = slice(-n_filters, None)
else:
sl = slice(-n_filters * (i + 1), -n_filters * i)
freqs_oct = freqs[sl]
alpha_oct = alpha[sl]
basis, basis_lengths = _build_wavelet(freqs_oct, my_sr, 1, alpha_oct)
n_fft_oct = basis.shape[1]
# Frequency-domain normalisation
basis = basis.astype(np.complex64)
basis *= basis_lengths[:, np.newaxis] / float(n_fft_oct)
fft_basis = scipy.fft.fft(basis, n=n_fft_oct, axis=1)[:, :(n_fft_oct // 2) + 1]
fft_basis = _sparsify_rows(fft_basis, quantile=0.01)
fft_basis = fft_basis * np.sqrt(sr / my_sr)
y_pad = np.pad(my_y, int(n_fft_oct // 2), mode='constant')
n_frames = 1 + (len(y_pad) - n_fft_oct) // my_hop
frames = np.lib.stride_tricks.as_strided(
y_pad,
shape=(n_fft_oct, n_frames),
strides=(y_pad.strides[0], y_pad.strides[0] * my_hop),
)
stft_result = scipy.fft.rfft(frames, axis=0)
cqt_resp.append(fft_basis.dot(stft_result))
if my_hop % 2 == 0:
my_hop //= 2
my_sr /= 2.0
my_y = _resample_half(my_y)
max_col = min(c.shape[-1] for c in cqt_resp)
cqt_out = np.empty((n_bins, max_col), dtype=np.complex64)
end = n_bins
for c_i in cqt_resp:
n_oct = c_i.shape[0]
if end < n_oct:
cqt_out[:end, :] = c_i[-end:, :max_col]
else:
cqt_out[end - n_oct:end, :] = c_i[:, :max_col]
end -= n_oct
cqt_out /= np.sqrt(lengths)[:, np.newaxis]
return np.abs(cqt_out).astype(np.float32)
def cq_to_chroma_mapping(n_input, bins_per_octave=12, n_chroma=12, fmin=None):
"""Map CQT bins to chroma bins."""
if fmin is None:
fmin = 32.70319566257483 # C1 note frequency
n_merge = bins_per_octave / n_chroma
cq_to_ch = np.repeat(np.eye(n_chroma), int(n_merge), axis=1)
cq_to_ch = np.roll(cq_to_ch, -int(n_merge // 2), axis=1)
n_octaves = int(np.ceil(n_input / bins_per_octave))
cq_to_ch = np.tile(cq_to_ch, n_octaves)[:, :n_input]
midi_0 = np.mod(12 * np.log2(fmin / 440.0) + 69, 12)
roll = int(np.round(midi_0 * (n_chroma / 12.0)))
cq_to_ch = np.roll(cq_to_ch, roll, axis=0)
return cq_to_ch.astype(np.float32)
def _parabolic_interpolation(S, axis=-2):
"""Compute parabolic interpolation shift for peak refinement."""
S_next = np.roll(S, -1, axis=axis)
S_prev = np.roll(S, 1, axis=axis)
a = S_next + S_prev - 2 * S
b = (S_next - S_prev) / 2.0
shifts = np.zeros_like(S)
valid = np.abs(b) < np.abs(a)
shifts[valid] = -b[valid] / a[valid]
if axis == -2 or axis == S.ndim - 2:
shifts[0, :] = 0
shifts[-1, :] = 0
elif axis == 0:
shifts[0, ...] = 0
shifts[-1, ...] = 0
return shifts
def _localmax(S, axis=-2):
"""Find local maxima along an axis."""
S_prev = np.roll(S, 1, axis=axis)
S_next = np.roll(S, -1, axis=axis)
local_max = (S > S_prev) & (S >= S_next)
if axis == -2 or axis == S.ndim - 2:
local_max[-1, :] = S[-1, :] > S[-2, :]
# First element is never a local max (strict inequality with previous)
local_max[0, :] = False
elif axis == 0:
local_max[-1, ...] = S[-1, ...] > S[-2, ...]
local_max[0, ...] = False
return local_max
def piptrack(y=None, sr=22050, S=None, n_fft=2048, hop_length=512,
fmin=150.0, fmax=4000.0, threshold=0.1):
"""Pitch tracking on thresholded parabolically-interpolated STFT."""
# Compute STFT if not provided
if S is None:
if y is None:
raise ValueError("Either y or S must be provided")
fft_window = scipy.signal.get_window('hann', n_fft, fftbins=True)
if len(fft_window) < n_fft:
lpad = int((n_fft - len(fft_window)) // 2)
fft_window = np.pad(fft_window, (lpad, int(n_fft - len(fft_window) - lpad)), mode='constant')
fft_window = fft_window.reshape((-1, 1))
y_pad = np.pad(y, int(n_fft // 2), mode='constant')
n_frames = 1 + (len(y_pad) - n_fft) // hop_length
frames = np.lib.stride_tricks.as_strided(
y_pad,
shape=(n_fft, n_frames),
strides=(y_pad.strides[0], y_pad.strides[0] * hop_length)
)
S = scipy.fft.rfft((fft_window * frames).astype(np.float32), axis=0)
S = np.abs(S)
fmin = max(fmin, 0)
fmax = min(fmax, float(sr) / 2)
fft_freqs = np.fft.rfftfreq(S.shape[0] * 2 - 2, 1.0 / sr)
if len(fft_freqs) > S.shape[0]:
fft_freqs = fft_freqs[:S.shape[0]]
shift = _parabolic_interpolation(S, axis=0)
avg = np.gradient(S, axis=0)
dskew = 0.5 * avg * shift
pitches = np.zeros_like(S)
mags = np.zeros_like(S)
freq_mask = (fmin <= fft_freqs) & (fft_freqs < fmax)
freq_mask = freq_mask.reshape(-1, 1)
ref_value = threshold * np.max(S, axis=0, keepdims=True)
local_max = _localmax(S * (S > ref_value), axis=0)
idx = np.nonzero(freq_mask & local_max)
pitches[idx] = (idx[0] + shift[idx]) * float(sr) / (S.shape[0] * 2 - 2)
mags[idx] = S[idx] + dskew[idx]
return pitches, mags
def hz_to_octs(frequencies, tuning=0.0, bins_per_octave=12):
"""Convert frequencies (Hz) to octave numbers."""
A440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
octs = np.log2(np.asanyarray(frequencies) / (float(A440) / 16))
return octs
def pitch_tuning(frequencies, resolution=0.01, bins_per_octave=12):
"""Estimate tuning offset from a collection of pitches."""
frequencies = np.atleast_1d(frequencies)
frequencies = frequencies[frequencies > 0]
if not np.any(frequencies):
return 0.0
residual = np.mod(bins_per_octave * hz_to_octs(frequencies, tuning=0.0,
bins_per_octave=bins_per_octave), 1.0)
residual[residual >= 0.5] -= 1.0
bins = np.linspace(-0.5, 0.5, int(np.ceil(1.0 / resolution)) + 1)
counts, tuning = np.histogram(residual, bins)
tuning_est = tuning[np.argmax(counts)]
return tuning_est
def estimate_tuning(y, sr=22050, bins_per_octave=12):
"""Estimate global tuning deviation from 12-TET."""
n_fft = 2048
hop_length = 512
if len(y) < n_fft:
return 0.0
pitch, mag = piptrack(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length,
fmin=150.0, fmax=4000.0, threshold=0.1)
pitch_mask = pitch > 0
if not pitch_mask.any():
return 0.0
threshold = np.median(mag[pitch_mask])
valid_pitches = pitch[(mag >= threshold) & pitch_mask]
if len(valid_pitches) == 0:
return 0.0
tuning = pitch_tuning(valid_pitches, resolution=0.01, bins_per_octave=bins_per_octave)
return float(tuning)
def compute_chroma_cens(y, sr=22050, hop_length=512, n_chroma=12,
n_octaves=7, bins_per_octave=36,
win_len_smooth=41, norm=2):
"""Compute Chroma Energy Normalized Statistics (CENS) features."""
tuning = estimate_tuning(y, sr, bins_per_octave=bins_per_octave)
fmin = 32.70319566257483 # C1 note frequency
n_bins = n_octaves * bins_per_octave
cqt_mag = compute_cqt(y, sr=sr, hop_length=hop_length,
fmin=fmin, n_bins=n_bins,
bins_per_octave=bins_per_octave,
tuning=tuning)
chroma_map = cq_to_chroma_mapping(n_bins, bins_per_octave=bins_per_octave,
n_chroma=n_chroma, fmin=fmin)
chroma = np.dot(chroma_map, cqt_mag)
threshold = np.finfo(chroma.dtype).tiny
chroma_sum = np.sum(np.abs(chroma), axis=0, keepdims=True)
chroma_sum = np.maximum(chroma_sum, threshold)
chroma = chroma / chroma_sum
quant_steps = [0.4, 0.2, 0.1, 0.05]
quant_weights = [0.25, 0.25, 0.25, 0.25]
chroma_quant = np.zeros_like(chroma)
for step, weight in zip(quant_steps, quant_weights):
chroma_quant += (chroma > step) * weight
if win_len_smooth is not None and win_len_smooth > 0:
win = scipy.signal.get_window('hann', win_len_smooth + 2, fftbins=False)
win /= np.sum(win)
win = win.reshape(1, -1)
chroma_smooth = scipy.ndimage.convolve(chroma_quant, win, mode='constant')
else:
chroma_smooth = chroma_quant
if norm == 2:
threshold = np.finfo(chroma_smooth.dtype).tiny
chroma_norm = np.sqrt(np.sum(chroma_smooth ** 2, axis=0, keepdims=True))
chroma_norm = np.maximum(chroma_norm, threshold)
chroma_smooth = chroma_smooth / chroma_norm
elif norm == np.inf:
threshold = np.finfo(chroma_smooth.dtype).tiny
chroma_norm = np.max(np.abs(chroma_smooth), axis=0, keepdims=True)
chroma_norm = np.maximum(chroma_norm, threshold)
chroma_smooth = chroma_smooth / chroma_norm
return chroma_smooth
def _create_mel_filterbank(sr, n_fft, n_mels=128, fmin=0.0, fmax=None):
"""Create mel-scale filterbank matrix."""
if fmax is None:
fmax = sr / 2.0
mel_basis = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=np.float32)
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
min_mel = hz_to_mel(fmin)
max_mel = hz_to_mel(fmax)
mels = np.linspace(min_mel, max_mel, n_mels + 2)
mel_f = mel_to_hz(mels)
fdiff = np.diff(mel_f)
ramps = np.subtract.outer(mel_f, fftfreqs)
for i in range(n_mels):
lower = -ramps[i] / fdiff[i]
upper = ramps[i + 2] / fdiff[i + 1]
mel_basis[i] = np.maximum(0, np.minimum(lower, upper))
enorm = 2.0 / (mel_f[2:n_mels + 2] - mel_f[:n_mels])
mel_basis *= enorm[:, np.newaxis]
return mel_basis
def _compute_mel_spectrogram(data, sr, n_fft=2048, hop_length=512, n_mels=128):
"""Compute mel spectrogram from audio signal."""
fft_window = scipy.signal.get_window('hann', n_fft, fftbins=True)
if len(fft_window) < n_fft:
lpad = int((n_fft - len(fft_window)) // 2)
fft_window = np.pad(fft_window, (lpad, int(n_fft - len(fft_window) - lpad)), mode='constant')
fft_window = fft_window.reshape((-1, 1))
data_padded = np.pad(data, int(n_fft // 2), mode='constant')
n_frames = 1 + (len(data_padded) - n_fft) // hop_length
shape = (n_fft, n_frames)
strides = (data_padded.strides[0], data_padded.strides[0] * hop_length)
frames = np.lib.stride_tricks.as_strided(data_padded, shape=shape, strides=strides)
stft_result = scipy.fft.rfft(fft_window * frames, axis=0).astype(np.complex64)
power_spec = np.abs(stft_result) ** 2
mel_basis = _create_mel_filterbank(sr, n_fft, n_mels=n_mels, fmin=0.0, fmax=sr / 2.0)
mel_spec = np.dot(mel_basis, power_spec)
return mel_spec.astype(np.float32)
def quick_tempo_estimate(audio_np, sr, start_bpm=120.0, std_bpm=1.0, hop_length=512):
"""Estimate tempo using autocorrelation tempogram."""
if len(audio_np) < hop_length * 10:
logging.warning("Audio too short for tempo estimation, returning default BPM of 120.0")
return 120.0
n_fft = 2048
mel_S = _compute_mel_spectrogram(audio_np, sr, n_fft=n_fft, hop_length=hop_length, n_mels=128)
log_mel_S = 10.0 * np.log10(np.maximum(1e-10, mel_S))
lag = 1
S_diff = log_mel_S[:, lag:] - log_mel_S[:, :-lag]
S_onset = np.maximum(0.0, S_diff)
onset_env_pre = np.mean(S_onset, axis=0)
pad_width = lag + n_fft // (2 * hop_length)
onset_env = np.pad(onset_env_pre, (pad_width, 0), mode='constant')
onset_env = onset_env[:mel_S.shape[1]]
return estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm, std_bpm, max_tempo=320.0)
def estimate_tempo_from_onset(onset_env, sr, hop_length, start_bpm=120.0, std_bpm=1.0, max_tempo=320.0):
"""Estimate tempo from onset strength envelope using autocorrelation tempogram."""
if len(onset_env) < 20:
return 120.0
ac_size = 8.0
win_length = int(np.round(ac_size * sr / hop_length))
win_length = min(win_length, len(onset_env))
pad_width = win_length // 2
onset_padded = np.pad(onset_env, (pad_width, pad_width), mode='linear_ramp', end_values=(0, 0))
n_frames = len(onset_env)
shape = (win_length, n_frames)
strides = (onset_padded.strides[0], onset_padded.strides[0])
frames = np.lib.stride_tricks.as_strided(onset_padded, shape=shape, strides=strides)
hann_window = scipy.signal.get_window('hann', win_length, fftbins=True)
windowed_frames = frames * hann_window[:, np.newaxis]
tempogram = np.zeros((win_length, n_frames))
for i in range(n_frames):
frame = windowed_frames[:, i]
n_pad = scipy.fft.next_fast_len(2 * len(frame) - 1)
fft_result = scipy.fft.rfft(frame, n=n_pad)
powspec = np.abs(fft_result) ** 2
ac = scipy.fft.irfft(powspec, n=n_pad)
tempogram[:, i] = ac[:win_length]
ac_max = np.max(np.abs(tempogram), axis=0)
mask = ac_max > 0
tempogram[:, mask] /= ac_max[mask]
tempogram_mean = np.mean(tempogram, axis=1)
tempogram_mean = np.maximum(tempogram_mean, 0)
bpms = np.zeros(win_length, dtype=np.float64)
bpms[0] = np.inf
bpms[1:] = 60.0 * sr / (hop_length * np.arange(1.0, win_length))
logprior = -0.5 * ((np.log2(bpms) - np.log2(start_bpm)) / std_bpm) ** 2
if max_tempo is not None:
max_idx = int(np.argmax(bpms < max_tempo))
if max_idx > 0:
logprior[:max_idx] = -np.inf
weighted = np.log1p(1e6 * tempogram_mean) + logprior
best_idx = int(np.argmax(weighted[1:])) + 1
tempo = bpms[best_idx]
return tempo
def detect_onset_peaks(onset_env, sr=22050, hop_length=512, pre_max=0.03, post_max=0.0,
pre_avg=0.10, post_avg=0.10, wait=0.03, delta=0.07):
"""Detect onset peaks using peak picking algorithm."""
onset_normalized = onset_env - np.min(onset_env)
onset_max = np.max(onset_normalized)
if onset_max > 0:
onset_normalized = onset_normalized / onset_max
pre_max_frames = int(pre_max * sr / hop_length)
post_max_frames = int(post_max * sr / hop_length) + 1
pre_avg_frames = int(pre_avg * sr / hop_length)
post_avg_frames = int(post_avg * sr / hop_length) + 1
wait_frames = int(wait * sr / hop_length)
peaks = np.zeros(len(onset_normalized), dtype=bool)
peaks[0] = (onset_normalized[0] >= np.max(onset_normalized[:min(post_max_frames, len(onset_normalized))]))
peaks[0] &= (onset_normalized[0] >= np.mean(onset_normalized[:min(post_avg_frames, len(onset_normalized))]) + delta)
if peaks[0]:
n = wait_frames + 1
else:
n = 1
while n < len(onset_normalized):
maxn = np.max(onset_normalized[max(0, n - pre_max_frames):min(n + post_max_frames, len(onset_normalized))])
peaks[n] = (onset_normalized[n] == maxn)
if not peaks[n]:
n += 1
continue
avgn = np.mean(onset_normalized[max(0, n - pre_avg_frames):min(n + post_avg_frames, len(onset_normalized))])
peaks[n] &= (onset_normalized[n] >= avgn + delta)
if not peaks[n]:
n += 1
continue
n += wait_frames + 1
return np.flatnonzero(peaks).astype(np.int32)
def track_beats(onset_env, tempo, sr, hop_length, tightness=100, trim=True):
"""Track beats using dynamic programming."""
frame_rate = sr / hop_length
frames_per_beat = np.round(frame_rate * 60.0 / tempo)
if frames_per_beat <= 0 or len(onset_env) < 2:
return np.array([], dtype=np.int32)
onset_std = np.std(onset_env, ddof=1)
if onset_std > 0:
onset_normalized = onset_env / onset_std
else:
onset_normalized = onset_env
window_range = np.arange(-frames_per_beat, frames_per_beat + 1)
window = np.exp(-0.5 * (window_range * 32.0 / frames_per_beat) ** 2)
localscore = scipy.signal.convolve(onset_normalized, window, mode='same')
backlink = np.full(len(localscore), -1, dtype=np.int32)
cumscore = np.zeros(len(localscore), dtype=np.float64)
score_thresh = 0.01 * localscore.max()
first_beat = True
backlink[0] = -1
cumscore[0] = localscore[0]
fpb = int(frames_per_beat)
for i in range(1, len(localscore)):
score_i = localscore[i]
best_score = -np.inf
beat_location = -1
search_start = int(i - np.round(fpb / 2.0))
search_end = int(i - 2 * fpb - 1)
for loc in range(search_start, search_end, -1):
if loc < 0:
break
score = cumscore[loc] - tightness * (np.log(i - loc) - np.log(fpb)) ** 2
if score > best_score:
best_score = score
beat_location = loc
if beat_location >= 0:
cumscore[i] = score_i + best_score
else:
cumscore[i] = score_i
if first_beat and score_i < score_thresh:
backlink[i] = -1
else:
backlink[i] = beat_location
first_beat = False
local_max_mask = np.zeros(len(cumscore), dtype=bool)
local_max_mask[0] = False
for i in range(1, len(cumscore) - 1):
local_max_mask[i] = (cumscore[i] > cumscore[i-1]) and (cumscore[i] >= cumscore[i+1])
if len(cumscore) > 1:
local_max_mask[-1] = cumscore[-1] > cumscore[-2]
if np.any(local_max_mask):
median_max = np.median(cumscore[local_max_mask])
threshold = 0.5 * median_max
tail = -1
for i in range(len(cumscore) - 1, -1, -1):
if local_max_mask[i] and cumscore[i] >= threshold:
tail = i
break
else:
tail = len(cumscore) - 1
beats = np.zeros(len(localscore), dtype=bool)
n = tail
visited = set()
while n >= 0 and n not in visited:
beats[n] = True
visited.add(n)
n = backlink[n]
if trim and np.any(beats):
beat_positions = np.flatnonzero(beats)
beat_localscores = localscore[beat_positions]
w = np.hanning(5)
smooth_boe_full = np.convolve(beat_localscores, w)
smooth_boe = smooth_boe_full[len(w)//2 : len(localscore) + len(w)//2]
threshold = 0.5 * np.sqrt(np.mean(smooth_boe ** 2))
start_frame = 0
while start_frame < len(localscore) and localscore[start_frame] <= threshold:
beats[start_frame] = False
start_frame += 1
end_frame = len(localscore) - 1
while end_frame >= 0 and localscore[end_frame] <= threshold:
beats[end_frame] = False
end_frame -= 1
return np.flatnonzero(beats).astype(np.int32)
def compute_onset_envelope(mel_spec_db, n_fft=2048, hop_length=512):
"""Compute onset strength envelope from a log-mel spectrogram (dB)."""
lag = 1
onset_diff = mel_spec_db[:, lag:] - mel_spec_db[:, :-lag]
onset_diff = np.maximum(0.0, onset_diff)
envelope_pre_pad = np.mean(onset_diff, axis=0)
pad_width = lag + n_fft // (2 * hop_length)
envelope = np.pad(envelope_pre_pad, (pad_width, 0), mode='constant')
envelope = envelope[:mel_spec_db.shape[1]]
return envelope
def compute_mfcc(mel_spec_db, n_mfcc=20):
"""Compute MFCC features from a log-mel spectrogram (dB)."""
mfcc = scipy.fft.dct(mel_spec_db, axis=0, type=2, norm='ortho')[:n_mfcc].T
return mfcc.astype(np.float32)
def power_to_db(S, amin=1e-10, top_db=80.0, ref=1.0):
"""Convert a power spectrogram (amplitude squared) to decibel (dB) units"""
S = np.asarray(S)
log_spec = 10.0 * np.log10(np.maximum(amin, S))
log_spec -= 10.0 * np.log10(np.maximum(amin, ref))
if top_db is not None:
log_spec = np.maximum(log_spec, log_spec.max() - top_db)
return log_spec
class WanDancerEncodeAudio(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerEncodeAudio",
category="conditioning/video_models",
inputs=[
io.Audio.Input("audio"),
io.Int.Input("video_frames", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Float.Input("audio_inject_scale", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="The scale for the audio features when injected into the video model."),
],
outputs=[
io.AudioEncoderOutput.Output(display_name="audio_encoder_output"),
io.String.Output(display_name="fps_string", tooltip="The calculated fps based on the audio length and the number of video frames. Used in the prompt."),
],
)
@classmethod
def execute(cls, video_frames, audio_inject_scale, audio) -> io.NodeOutput:
waveform = audio["waveform"][0]
sample_rate = audio["sample_rate"]
base_fps = 30
hop_length = 512
model_sr = 22050
n_fft = 2048
# start tempo from original audio (not the resampled one) to match the reference pipeline
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=False)
start_bpm = quick_tempo_estimate(waveform.squeeze().cpu().numpy(), sample_rate, hop_length=hop_length)
# resample to the sample rate used for feature extraction
resample_sr = base_fps * hop_length
waveform = torchaudio.functional.resample(waveform, sample_rate, resample_sr)
waveform_np = waveform.cpu().numpy().squeeze()
mel_spec = _compute_mel_spectrogram(waveform_np, model_sr, n_fft, hop_length, n_mels=128)
mel_spec_db = power_to_db(mel_spec, amin=1e-10, top_db=80.0, ref=1.0)
envelope = compute_onset_envelope(mel_spec_db, n_fft, hop_length)
mfcc = compute_mfcc(mel_spec_db, n_mfcc=20)
chroma = compute_chroma_cens(y=waveform_np, sr=model_sr, hop_length=hop_length).T
# detect peaks
peak_idxs = detect_onset_peaks(envelope, sr=model_sr, hop_length=hop_length)
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
peak_onehot[peak_idxs] = 1.0
# detect beats
beat_tracking_tempo = estimate_tempo_from_onset(envelope, sr=model_sr, hop_length=hop_length, start_bpm=start_bpm)
beat_idxs = track_beats(envelope, beat_tracking_tempo, model_sr, hop_length, tightness=100, trim=True)
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
beat_onehot[beat_idxs] = 1.0
audio_feature = np.concatenate(
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
axis=-1,
)
audio_feature = torch.from_numpy(audio_feature).unsqueeze(0).to(comfy.model_management.intermediate_device())
fps = float(base_fps / int(audio_feature.shape[1] / video_frames + 0.5))
audio_encoder_output = {
"audio_feature": audio_feature,
"fps": fps,
"audio_inject_scale": audio_inject_scale,
}
if int(fps + 0.5) != 30:
fps_string = " 帧率是{:.4f}".format(fps) # "frame rate is" in Chinese, as it was in the original pipeline
else:
fps_string = ", 帧率是30fps。" # to match the reference pipeline when the fps is 30
return io.NodeOutput(audio_encoder_output, fps_string)
class WanDancerVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="The number of frames in the generated video. Should stay 149 for WanDancer."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="The CLIP vision embeds for the first frame."),
io.ClipVisionOutput.Input("clip_vision_output_ref", optional=True, tooltip="The CLIP vision embeds for the reference image."),
io.Image.Input("start_image", optional=True, tooltip="The initial image(s) to be encoded, can be any number of frames."),
io.Mask.Input("mask", optional=True, tooltip="Image conditioning mask for the start image(s). White is kept, black is generated. Used for the local generations."),
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent."),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, start_image=None, mask=None, clip_vision_output=None, clip_vision_output_ref=None, audio_encoder_output=None) -> io.NodeOutput:
latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
image = torch.zeros((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
image[:start_image.shape[0]] = start_image
concat_latent_image = vae.encode(image[:, :, :, :3])
if mask is None:
concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
else:
concat_mask = 1 - mask[:length].unsqueeze(0)
concat_mask = comfy.utils.common_upscale(concat_mask, concat_latent_image.shape[-2], concat_latent_image.shape[-1], "nearest-exact", "disabled")
concat_mask = torch.cat([torch.repeat_interleave(concat_mask[:, 0:1], repeats=4, dim=1), concat_mask[:, 1:]], dim=1)
concat_mask = concat_mask.view(1, concat_mask.shape[1] // 4, 4, concat_latent_image.shape[-2], concat_latent_image.shape[-1]).transpose(1, 2)
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output, "clip_vision_output_ref": clip_vision_output_ref})
if audio_encoder_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_encoder_output["audio_feature"], "fps": audio_encoder_output["fps"], "audio_inject_scale": audio_encoder_output.get("audio_inject_scale", 1.0)})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanDancerPadKeyframes(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerPadKeyframes",
category="image/video",
inputs=[
io.Image.Input("images",),
io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of this segment (usually 149 frames)"),
io.Int.Input("segment_index", default=0, min=0, max=100, tooltip="Which segment this is (0 for first, 1 for second, etc.)"),
io.Audio.Input("audio", tooltip="Audio to calculate total output frames from and extract segment audio."),
],
outputs=[
io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequence"),
io.Mask.Output(display_name="keyframes_mask", tooltip="Mask indicating valid frames"),
io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for this video segment"),
],
)
@classmethod
def do_execute(cls, images, segment_length, segment_index, audio):
B, H, W, C = images.shape
fps = 30
# calculate total frames
audio_duration = audio["waveform"].shape[-1] / audio["sample_rate"]
segment_duration = segment_length / fps
buffer = 0.2
num_segments = int((audio_duration - buffer) / segment_duration) + 1 if audio_duration > buffer else 0
total_frames = num_segments * segment_length
mask = torch.zeros((segment_length, H, W), device=images.device, dtype=images.dtype)
keyframes = torch.zeros((segment_length, H, W, C), dtype=images.dtype, device=images.device)
# guard: with no audio or no images, nothing to place — leave keyframes/mask zeroed
if total_frames > 0 and B > 0:
frame_interval = float(total_frames) / B
seg_num = int(math.ceil(total_frames / segment_length))
is_last_segment = (segment_index == seg_num - 1)
positions = []
images_before_this_segment = 0
# count images consumed by previous segments
for seg_idx in range(segment_index):
end_idx = (total_frames - segment_length * seg_idx - 1) if seg_idx == seg_num - 1 else (segment_length - 1)
cnt = 0
while cnt * frame_interval < end_idx - frame_interval:
cnt += 1
images_before_this_segment += cnt
# positions for current segment
end_index = (total_frames - segment_length * segment_index - 1) if is_last_segment else (segment_length - 1)
cnt = 0
while cnt * frame_interval < end_index - frame_interval:
pos = int(math.ceil(frame_interval * cnt))
positions.append((pos, images_before_this_segment + cnt))
cnt += 1
positions.append((end_index, images_before_this_segment + cnt))
valid_positions = [(pos, idx) for pos, idx in positions if idx < B and pos < segment_length]
if valid_positions:
seg_positions, img_indices = zip(*valid_positions)
seg_positions = torch.tensor(seg_positions, dtype=torch.long, device=images.device)
img_indices = torch.tensor(img_indices, dtype=torch.long, device=images.device)
mask[seg_positions] = 1
keyframes[seg_positions] = images[img_indices]
# extract audio segment
segment_duration = segment_length / fps
start_time = segment_index * segment_duration
end_time = min(start_time + segment_duration, audio_duration)
sample_rate = audio["sample_rate"]
start_sample = int(start_time * sample_rate)
end_sample = int(end_time * sample_rate)
audio_segment_waveform = audio["waveform"][:, :, start_sample:end_sample]
audio_segment = {
"waveform": audio_segment_waveform,
"sample_rate": sample_rate
}
return keyframes, mask, audio_segment
@classmethod
def execute(cls, images, segment_length, segment_index, audio=None) -> io.NodeOutput:
return io.NodeOutput(*cls.do_execute(images, segment_length, segment_index, audio))
class WanDancerPadKeyframesList(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanDancerPadKeyframesList",
category="image/video",
inputs=[
io.Image.Input("images"),
io.Int.Input("segment_length", default=149, min=1, max=10000, tooltip="Length of each segment (usually 149 frames)"),
io.Int.Input("num_segments", default=1, min=1, max=100, tooltip="How many padded segments to emit as lists."),
io.Audio.Input("audio", tooltip="Audio to slice for each emitted segment."),
],
outputs=[
io.Image.Output(display_name="keyframes_sequence", tooltip="Padded keyframe sequences", is_output_list=True),
io.Mask.Output(display_name="keyframes_mask", tooltip="Masks indicating valid frames", is_output_list=True),
io.Audio.Output(display_name="audio_segment", tooltip="Audio segment for each video segment", is_output_list=True),
],
)
@classmethod
def execute(cls, images, segment_length, num_segments, audio=None) -> io.NodeOutput:
outputs = [WanDancerPadKeyframes.do_execute(images, segment_length, i, audio) for i in range(num_segments)]
keyframes, masks, audio_segments = zip(*outputs)
return io.NodeOutput(list(keyframes), list(masks), list(audio_segments))
class WanDancerExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanDancerVideo,
WanDancerEncodeAudio,
WanDancerPadKeyframes,
WanDancerPadKeyframesList,
]
async def comfy_entrypoint() -> WanDancerExtension:
return WanDancerExtension()

View File

@ -52,8 +52,6 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)

View File

@ -2429,12 +2429,10 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
"nodes_bg_removal.py",
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",
"nodes_void.py",
"nodes_wandancer.py",
]
import_failed = []

File diff suppressed because it is too large Load Diff

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.43.18
comfyui-frontend-package==1.43.17
comfyui-workflow-templates==0.9.72
comfyui-embedded-docs==0.4.4
torch

View File

@ -124,11 +124,9 @@ class TestMathExpressionExecute:
with pytest.raises(Exception, match="not defined"):
self._exec("str(a)", a=42)
def test_boolean_result(self):
result = self._exec("a > b", a=5, b=3)
assert result[2] is True
result = self._exec("a > b", a=3, b=5)
assert result[2] is False
def test_boolean_result_raises(self):
with pytest.raises(ValueError, match="got bool"):
self._exec("a > b", a=5, b=3)
def test_empty_expression_raises(self):
with pytest.raises(ValueError, match="Expression cannot be empty"):